67
.gitignore
vendored
Normal file
67
.gitignore
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
# 가상환경 (용량 크므로 제외)
|
||||
.venv/
|
||||
.venv_client/
|
||||
|
||||
# 모델 파일 (용량 매우 큼)
|
||||
*.pt
|
||||
*.pth
|
||||
*.onnx
|
||||
*.gz
|
||||
|
||||
# 로그
|
||||
X-AnyLabeling-Server/logs/
|
||||
*.log
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
|
||||
# 캐시
|
||||
.cache/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# 임시 파일
|
||||
create_shortcuts.ps1
|
||||
|
||||
# 클로드 개인 설정/메모리
|
||||
.claude/
|
||||
|
||||
# X-AnyLabeling-Server (별도 git repo, submodule 아님)
|
||||
X-AnyLabeling-Server/
|
||||
|
||||
# 용량 큰 데이터/결과 폴더
|
||||
drone_2cm/
|
||||
output/
|
||||
runs/
|
||||
data/
|
||||
|
||||
# 용량 큰 원본 영상 파일
|
||||
*.tif
|
||||
경부선_2-2구간_미션33-34(5cm).png
|
||||
|
||||
# 논문/문서 파일
|
||||
*.pdf
|
||||
*.txt
|
||||
|
||||
# 미정리 툴 (작업 중)
|
||||
tools/render_polygons_rainbow.py
|
||||
tools/render_yolo_labels.py
|
||||
tools/sam3_sleeper_detect.py
|
||||
tools/tif_hollow_detect.py
|
||||
|
||||
# 기타
|
||||
src/
|
||||
PRPs/
|
||||
nul
|
||||
triton-*.whl
|
||||
pyproject.toml
|
||||
Project_Status.md
|
||||
.usage/
|
||||
65
CLAUDE.md
Normal file
65
CLAUDE.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# CLAUDE.md
|
||||
|
||||
Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed.
|
||||
|
||||
**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment.
|
||||
|
||||
## 1. Think Before Coding
|
||||
|
||||
**Don't assume. Don't hide confusion. Surface tradeoffs.**
|
||||
|
||||
Before implementing:
|
||||
- State your assumptions explicitly. If uncertain, ask.
|
||||
- If multiple interpretations exist, present them - don't pick silently.
|
||||
- If a simpler approach exists, say so. Push back when warranted.
|
||||
- If something is unclear, stop. Name what's confusing. Ask.
|
||||
|
||||
## 2. Simplicity First
|
||||
|
||||
**Minimum code that solves the problem. Nothing speculative.**
|
||||
|
||||
- No features beyond what was asked.
|
||||
- No abstractions for single-use code.
|
||||
- No "flexibility" or "configurability" that wasn't requested.
|
||||
- No error handling for impossible scenarios.
|
||||
- If you write 200 lines and it could be 50, rewrite it.
|
||||
|
||||
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
|
||||
|
||||
## 3. Surgical Changes
|
||||
|
||||
**Touch only what you must. Clean up only your own mess.**
|
||||
|
||||
When editing existing code:
|
||||
- Don't "improve" adjacent code, comments, or formatting.
|
||||
- Don't refactor things that aren't broken.
|
||||
- Match existing style, even if you'd do it differently.
|
||||
- If you notice unrelated dead code, mention it - don't delete it.
|
||||
|
||||
When your changes create orphans:
|
||||
- Remove imports/variables/functions that YOUR changes made unused.
|
||||
- Don't remove pre-existing dead code unless asked.
|
||||
|
||||
The test: Every changed line should trace directly to the user's request.
|
||||
|
||||
## 4. Goal-Driven Execution
|
||||
|
||||
**Define success criteria. Loop until verified.**
|
||||
|
||||
Transform tasks into verifiable goals:
|
||||
- "Add validation" → "Write tests for invalid inputs, then make them pass"
|
||||
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
|
||||
- "Refactor X" → "Ensure tests pass before and after"
|
||||
|
||||
For multi-step tasks, state a brief plan:
|
||||
```
|
||||
1. [Step] → verify: [check]
|
||||
2. [Step] → verify: [check]
|
||||
3. [Step] → verify: [check]
|
||||
```
|
||||
|
||||
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
|
||||
|
||||
---
|
||||
|
||||
**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes.
|
||||
143
configs/railway_zone.json
Normal file
143
configs/railway_zone.json
Normal file
@@ -0,0 +1,143 @@
|
||||
{
|
||||
"_comment": "철도 구간 집중 검출용 카테고리 설정 (타일 9-24 기준)",
|
||||
"_conf_note": "conf: 클래스별 신뢰도 임계값. priority: 낮을수록 cross-class NMS에서 우선 보존(화면 위에 표시).",
|
||||
"_priority_order": "2=컨트롤박스 > 3=전철주 > 4=팬스 > 5=레일 > 6=침목 > 7=자갈(맨 밑)",
|
||||
"cross_class_nms_iou": 0.45,
|
||||
"categories": [
|
||||
{
|
||||
"name": "control_box",
|
||||
"name_kr": "컨트롤박스",
|
||||
"prompt": "small square gray metal box beside rail, compact trackside junction box, small near-square electrical enclosure on the ground, small cube-shaped equipment box next to track",
|
||||
"keywords": ["small square box", "junction box", "compact enclosure", "square metal box", "small cube box"],
|
||||
"color_bgr": [255, 255, 0],
|
||||
"conf": 0.15,
|
||||
"priority": 2
|
||||
},
|
||||
{
|
||||
"name": "catenary_pole",
|
||||
"name_kr": "전철주",
|
||||
"prompt": "railway catenary pole, overhead line pole, catenary mast, electric railway pole",
|
||||
"keywords": ["catenary pole", "overhead line pole", "catenary mast", "railway pole"],
|
||||
"color_bgr": [255, 0, 255],
|
||||
"conf": 0.25,
|
||||
"priority": 3
|
||||
},
|
||||
{
|
||||
"name": "fence",
|
||||
"name_kr": "팬스/울타리",
|
||||
"prompt": "railway fence, trackside fence, perimeter fence, chain link fence, railway boundary fence",
|
||||
"keywords": ["fence", "chain link", "perimeter fence", "boundary fence"],
|
||||
"color_bgr": [0, 255, 0],
|
||||
"conf": 0.50,
|
||||
"priority": 5
|
||||
},
|
||||
{
|
||||
"name": "railway",
|
||||
"name_kr": "철도 레일",
|
||||
"prompt": "two parallel steel rails, narrow metallic longitudinal beam, shiny steel rail line on sleeper",
|
||||
"keywords": ["railroad", "railway rail", "steel rail", "train track", "parallel rail"],
|
||||
"color_bgr": [0, 0, 255],
|
||||
"conf": 0.25,
|
||||
"priority": 4
|
||||
},
|
||||
{
|
||||
"name": "sleeper",
|
||||
"name_kr": "침목",
|
||||
"prompt": "rectangular concrete sleeper, dark crosswise tie, evenly spaced railroad tie perpendicular to rail, concrete slab between ballast",
|
||||
"keywords": ["sleeper", "railroad tie", "rail tie", "wooden tie", "concrete tie", "crosswise tie"],
|
||||
"color_bgr": [0, 128, 255],
|
||||
"conf": 0.20,
|
||||
"priority": 6
|
||||
},
|
||||
{
|
||||
"name": "ballast",
|
||||
"name_kr": "자갈도상",
|
||||
"prompt": "crushed gray stone gravel, railway ballast aggregate, coarse gravel track bed, angular stone between rails",
|
||||
"keywords": ["ballast", "gravel between rails", "track bed", "crushed stone", "aggregate"],
|
||||
"color_bgr": [30, 50, 100],
|
||||
"conf": 0.20,
|
||||
"priority": 7
|
||||
},
|
||||
{
|
||||
"name": "bracket",
|
||||
"name_kr": "브라켓/암",
|
||||
"prompt": "catenary bracket arm, horizontal cantilever arm, pole cross arm, overhead wire bracket",
|
||||
"keywords": ["bracket", "cantilever", "cross arm"],
|
||||
"color_bgr": [0, 80, 200],
|
||||
"conf": 0.20,
|
||||
"priority": 8
|
||||
},
|
||||
{
|
||||
"name": "bridge",
|
||||
"name_kr": "교량/교각",
|
||||
"prompt": "railway bridge, road bridge, overpass, viaduct, concrete bridge structure",
|
||||
"keywords": ["bridge", "overpass", "viaduct"],
|
||||
"color_bgr": [0, 165, 255],
|
||||
"conf": 0.25,
|
||||
"priority": 8
|
||||
},
|
||||
{
|
||||
"name": "retaining_wall",
|
||||
"name_kr": "방음벽/옹벽",
|
||||
"prompt": "railway retaining wall, noise barrier wall, sound barrier, concrete embankment wall",
|
||||
"keywords": ["retaining wall", "noise barrier", "sound barrier", "embankment wall"],
|
||||
"color_bgr": [60, 180, 60],
|
||||
"conf": 0.25,
|
||||
"priority": 9
|
||||
},
|
||||
{
|
||||
"name": "service_road",
|
||||
"name_kr": "유지보수 도로",
|
||||
"prompt": "railway maintenance road, service road alongside track, unpaved railway access road",
|
||||
"keywords": ["service road", "maintenance road", "access road"],
|
||||
"color_bgr": [80, 120, 160],
|
||||
"conf": 0.30,
|
||||
"priority": 10
|
||||
},
|
||||
{
|
||||
"name": "culvert",
|
||||
"name_kr": "암거/소교량",
|
||||
"prompt": "railway culvert, drainage culvert, small underpass tunnel, concrete drainage structure",
|
||||
"keywords": ["culvert", "underpass", "drainage tunnel"],
|
||||
"color_bgr": [180, 60, 180],
|
||||
"conf": 0.20,
|
||||
"priority": 9
|
||||
},
|
||||
{
|
||||
"name": "vehicle",
|
||||
"name_kr": "차량",
|
||||
"prompt": "car, truck, vehicle, automobile, van",
|
||||
"keywords": ["car", "truck", "vehicle", "automobile", "van"],
|
||||
"color_bgr": [255, 255, 255],
|
||||
"conf": 0.25,
|
||||
"priority": 2
|
||||
},
|
||||
{
|
||||
"name": "building",
|
||||
"name_kr": "건물",
|
||||
"prompt": "building, house, rooftop, structure, roof",
|
||||
"keywords": ["building", "house", "rooftop", "structure", "roof"],
|
||||
"color_bgr": [50, 50, 255],
|
||||
"conf": 0.25,
|
||||
"priority": 8
|
||||
},
|
||||
{
|
||||
"name": "farmland",
|
||||
"name_kr": "농지",
|
||||
"prompt": "farmland, agricultural field, cropland, vegetable garden",
|
||||
"keywords": ["farmland", "field", "cropland", "vegetable", "agricultural"],
|
||||
"color_bgr": [50, 200, 50],
|
||||
"conf": 0.25,
|
||||
"priority": 11
|
||||
},
|
||||
{
|
||||
"name": "vegetation",
|
||||
"name_kr": "식생",
|
||||
"prompt": "trees, forest, shrubs, vegetation, bushes",
|
||||
"keywords": ["tree", "forest", "shrub", "vegetation", "bush"],
|
||||
"color_bgr": [0, 120, 0],
|
||||
"conf": 0.25,
|
||||
"priority": 12
|
||||
}
|
||||
]
|
||||
}
|
||||
301
tools/auto_rail_detect.py
Normal file
301
tools/auto_rail_detect.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
auto_rail_detect.py
|
||||
===================
|
||||
항공 이미지에서 레일 라인을 자동 검출하여 Rhino용 DXF로 저장.
|
||||
수동 라벨링 없이 이미지 1장에서 바로 실행.
|
||||
|
||||
사용법:
|
||||
python tools/auto_rail_detect.py <image.png> [output.dxf]
|
||||
|
||||
예시:
|
||||
python tools/auto_rail_detect.py "경부선_2-2구간_미션33-34(5cm).png"
|
||||
python tools/auto_rail_detect.py "경부선.png" output/rail.dxf
|
||||
|
||||
원리:
|
||||
이미지 → 엣지 검출 → Hough 직선 검출 → 방향별 클러스터링 → DXF 폴리라인
|
||||
"""
|
||||
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ─── 설정 (조정 가능) ─────────────────────────────────────────────────────────
|
||||
CANNY_LOW = 30 # Canny 엣지 낮은 임계값 (낮을수록 엣지 많이 검출)
|
||||
CANNY_HIGH = 80 # Canny 엣지 높은 임계값
|
||||
HOUGH_THRESHOLD = 100 # Hough 투표 임계값 (높을수록 긴 선만 검출)
|
||||
HOUGH_MIN_LEN = 200 # 최소 선 길이 (픽셀)
|
||||
HOUGH_MAX_GAP = 30 # 선 간격 허용 (픽셀, 클수록 끊긴 선도 연결)
|
||||
ANGLE_TOLERANCE = 5 # 같은 방향으로 볼 각도 범위 (도)
|
||||
CLUSTER_DIST = 20 # 같은 레일로 볼 선 간격 (픽셀)
|
||||
MIN_TOTAL_LEN = 500 # 최종 레일 최소 총 길이 (픽셀, 짧은 잡선 제거)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def detect_edges(img_gray):
|
||||
"""CLAHE 대비 강화 → Gaussian 블러 → Canny 엣지"""
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(img_gray)
|
||||
blurred = cv2.GaussianBlur(enhanced, (5, 5), 0)
|
||||
edges = cv2.Canny(blurred, CANNY_LOW, CANNY_HIGH, apertureSize=3)
|
||||
return edges
|
||||
|
||||
|
||||
def detect_lines(edges):
|
||||
"""확률적 Hough 변환으로 직선 검출"""
|
||||
lines = cv2.HoughLinesP(
|
||||
edges,
|
||||
rho=1,
|
||||
theta=np.pi / 180,
|
||||
threshold=HOUGH_THRESHOLD,
|
||||
minLineLength=HOUGH_MIN_LEN,
|
||||
maxLineGap=HOUGH_MAX_GAP,
|
||||
)
|
||||
if lines is None:
|
||||
return []
|
||||
return [tuple(l[0]) for l in lines] # [(x1,y1,x2,y2), ...]
|
||||
|
||||
|
||||
def line_angle(x1, y1, x2, y2):
|
||||
"""선의 각도 (0~180도)"""
|
||||
angle = np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180
|
||||
return angle
|
||||
|
||||
|
||||
def line_length(x1, y1, x2, y2):
|
||||
return np.hypot(x2 - x1, y2 - y1)
|
||||
|
||||
|
||||
def line_midpoint(x1, y1, x2, y2):
|
||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
|
||||
def perpendicular_dist(x1, y1, x2, y2, px, py):
|
||||
"""점 (px,py)에서 선 (x1,y1)-(x2,y2)까지 수직거리"""
|
||||
dx, dy = x2 - x1, y2 - y1
|
||||
length = np.hypot(dx, dy)
|
||||
if length == 0:
|
||||
return np.hypot(px - x1, py - y1)
|
||||
return abs(dy * px - dx * py + x2 * y1 - y2 * x1) / length
|
||||
|
||||
|
||||
def cluster_lines(lines):
|
||||
"""비슷한 방향 + 가까운 위치의 선들을 같은 레일로 묶기"""
|
||||
if not lines:
|
||||
return []
|
||||
|
||||
# 각도별 그룹 먼저 분리
|
||||
angle_groups = {}
|
||||
for line in lines:
|
||||
x1, y1, x2, y2 = line
|
||||
ang = round(line_angle(x1, y1, x2, y2) / ANGLE_TOLERANCE) * ANGLE_TOLERANCE
|
||||
angle_groups.setdefault(ang, []).append(line)
|
||||
|
||||
clusters = []
|
||||
for ang, grp in angle_groups.items():
|
||||
# 같은 방향 그룹 내에서 거리 기반 클러스터링
|
||||
used = [False] * len(grp)
|
||||
for i, line_i in enumerate(grp):
|
||||
if used[i]:
|
||||
continue
|
||||
cluster = [line_i]
|
||||
used[i] = True
|
||||
mx_i, my_i = line_midpoint(*line_i)
|
||||
for j, line_j in enumerate(grp):
|
||||
if used[j]:
|
||||
continue
|
||||
mx_j, my_j = line_midpoint(*line_j)
|
||||
dist = perpendicular_dist(*line_i, mx_j, my_j)
|
||||
if dist < CLUSTER_DIST:
|
||||
cluster.append(line_j)
|
||||
used[j] = True
|
||||
clusters.append(cluster)
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def merge_cluster_to_polyline(cluster):
|
||||
"""클러스터의 선들을 하나의 정렬된 폴리라인으로 합치기"""
|
||||
# 모든 끝점 수집
|
||||
all_pts = []
|
||||
for x1, y1, x2, y2 in cluster:
|
||||
all_pts.append((x1, y1))
|
||||
all_pts.append((x2, y2))
|
||||
|
||||
if not all_pts:
|
||||
return []
|
||||
|
||||
# 주성분 방향으로 정렬
|
||||
pts_arr = np.array(all_pts, dtype=float)
|
||||
mean = pts_arr.mean(axis=0)
|
||||
centered = pts_arr - mean
|
||||
cov = np.cov(centered.T)
|
||||
if cov.ndim < 2:
|
||||
direction = np.array([1.0, 0.0])
|
||||
else:
|
||||
eigenvalues, eigenvectors = np.linalg.eig(cov)
|
||||
direction = eigenvectors[:, np.argmax(eigenvalues)]
|
||||
|
||||
# 방향으로 투영하여 정렬
|
||||
projections = centered.dot(direction)
|
||||
sorted_idx = np.argsort(projections)
|
||||
sorted_pts = pts_arr[sorted_idx]
|
||||
|
||||
# 중복 제거 (가까운 점 합치기)
|
||||
merged = [sorted_pts[0]]
|
||||
for pt in sorted_pts[1:]:
|
||||
if np.hypot(pt[0] - merged[-1][0], pt[1] - merged[-1][1]) > 5:
|
||||
merged.append(pt)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def total_polyline_length(polyline):
|
||||
if len(polyline) < 2:
|
||||
return 0
|
||||
total = 0
|
||||
for i in range(len(polyline) - 1):
|
||||
total += np.hypot(
|
||||
polyline[i+1][0] - polyline[i][0],
|
||||
polyline[i+1][1] - polyline[i][1]
|
||||
)
|
||||
return total
|
||||
|
||||
|
||||
def save_debug_image(img, all_polylines, output_path):
|
||||
"""검출 결과를 이미지에 시각화 (확인용)"""
|
||||
vis = img.copy() if len(img.shape) == 3 else cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255),
|
||||
(255, 0, 255), (255, 255, 0)]
|
||||
for i, poly in enumerate(all_polylines):
|
||||
color = colors[i % len(colors)]
|
||||
pts = np.array([[int(p[0]), int(p[1])] for p in poly])
|
||||
cv2.polylines(vis, [pts], False, color, 2)
|
||||
# 번호 표시
|
||||
cx, cy = int(poly[len(poly)//2][0]), int(poly[len(poly)//2][1])
|
||||
cv2.putText(vis, str(i+1), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1.5, color, 3)
|
||||
cv2.imwrite(str(output_path), vis)
|
||||
print(f" 시각화 저장: {output_path}")
|
||||
|
||||
|
||||
def process(image_path: str, dxf_path: str):
|
||||
import ezdxf
|
||||
|
||||
print(f"[입력] {image_path}")
|
||||
# 한글 경로 대응: numpy로 읽어서 cv2 디코딩
|
||||
import numpy as np_io
|
||||
with open(image_path, "rb") as f:
|
||||
data = f.read()
|
||||
img_arr = np_io.frombuffer(data, dtype=np_io.uint8)
|
||||
img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
print(f"오류: 이미지를 열 수 없습니다: {image_path}")
|
||||
sys.exit(1)
|
||||
|
||||
H, W = img.shape[:2]
|
||||
print(f"[이미지] {W} x {H} px")
|
||||
|
||||
# 전처리: 작은 이미지로 축소 (속도)
|
||||
scale = 1.0
|
||||
if W > 4000:
|
||||
scale = 4000 / W
|
||||
small = cv2.resize(img, (int(W * scale), int(H * scale)))
|
||||
print(f" 축소: {int(W*scale)} x {int(H*scale)} (scale={scale:.2f})")
|
||||
else:
|
||||
small = img.copy()
|
||||
|
||||
gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# 엣지 검출
|
||||
print("[1단계] 엣지 검출...")
|
||||
edges = detect_edges(gray)
|
||||
edge_count = int(edges.sum() / 255)
|
||||
print(f" 엣지 픽셀: {edge_count:,}")
|
||||
|
||||
# Hough 직선 검출
|
||||
print("[2단계] Hough 직선 검출...")
|
||||
lines = detect_lines(edges)
|
||||
print(f" 검출된 선: {len(lines)}개")
|
||||
|
||||
if not lines:
|
||||
print("선을 검출하지 못했습니다. HOUGH_THRESHOLD를 낮춰보세요.")
|
||||
sys.exit(1)
|
||||
|
||||
# 클러스터링
|
||||
print("[3단계] 레일 클러스터링...")
|
||||
clusters = cluster_lines(lines)
|
||||
print(f" 클러스터: {len(clusters)}개")
|
||||
|
||||
# 폴리라인 변환 + 길이 필터
|
||||
polylines = []
|
||||
for cluster in clusters:
|
||||
poly = merge_cluster_to_polyline(cluster)
|
||||
length = total_polyline_length(poly)
|
||||
if length >= MIN_TOTAL_LEN / scale:
|
||||
# 원본 해상도로 역변환
|
||||
poly_orig = [(p[0] / scale, p[1] / scale) for p in poly]
|
||||
polylines.append((poly_orig, length / scale))
|
||||
|
||||
# 길이순 정렬
|
||||
polylines.sort(key=lambda x: -x[1])
|
||||
print(f" 유효 레일 라인: {len(polylines)}개")
|
||||
|
||||
if not polylines:
|
||||
print("유효한 레일을 찾지 못했습니다. MIN_TOTAL_LEN을 낮춰보세요.")
|
||||
sys.exit(1)
|
||||
|
||||
for i, (poly, length) in enumerate(polylines):
|
||||
print(f" 레일 {i+1}: {length:.0f}px ({length*0.05:.1f}m)")
|
||||
|
||||
# DXF 저장
|
||||
print("[4단계] DXF 저장...")
|
||||
doc = ezdxf.new("R2010")
|
||||
msp = doc.modelspace()
|
||||
doc.layers.add("RAIL_AUTO", color=1) # 빨강 — 자동검출 레일
|
||||
doc.layers.add("RAIL_AUTO_LABEL", color=2) # 노랑 — 번호
|
||||
|
||||
for i, (poly, length) in enumerate(polylines):
|
||||
# Y축 반전 (이미지→DXF)
|
||||
dxf_pts = [(float(p[0]), float(-p[1])) for p in poly]
|
||||
msp.add_lwpolyline(dxf_pts, dxfattribs={"layer": "RAIL_AUTO"})
|
||||
|
||||
# 중간점에 번호 텍스트
|
||||
mid = poly[len(poly) // 2]
|
||||
msp.add_text(
|
||||
f"Rail_{i+1}",
|
||||
dxfattribs={
|
||||
"layer": "RAIL_AUTO_LABEL",
|
||||
"height": 20,
|
||||
"insert": (float(mid[0]), float(-mid[1])),
|
||||
}
|
||||
)
|
||||
|
||||
Path(dxf_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.saveas(dxf_path)
|
||||
print(f"[완료] DXF 저장: {dxf_path}")
|
||||
print(f" 레이어: RAIL_AUTO(빨강) = {len(polylines)}개 레일 중심선")
|
||||
|
||||
# 디버그 이미지 저장
|
||||
debug_path = Path(dxf_path).parent / (Path(dxf_path).stem + "_debug.jpg")
|
||||
all_polys = [p for p, _ in polylines]
|
||||
save_debug_image(small, [([(x*scale, y*scale) for x, y in p]) for p in all_polys],
|
||||
debug_path)
|
||||
|
||||
print(f"\nRhino 사용법:")
|
||||
print(f" 1. File -> Import -> {Path(dxf_path).name}")
|
||||
print(f" 2. RAIL_AUTO 레이어 선택")
|
||||
print(f" 3. Sweep1 명령 -> 레일 단면 선택 -> 3D 레일 완성")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("사용법: python tools/auto_rail_detect.py <image.png> [output.dxf]")
|
||||
sys.exit(1)
|
||||
|
||||
img_path = sys.argv[1]
|
||||
dxf_path = sys.argv[2] if len(sys.argv) >= 3 else str(
|
||||
Path(img_path).with_suffix(".dxf")
|
||||
)
|
||||
process(img_path, dxf_path)
|
||||
565
tools/detect_all_objects.py
Normal file
565
tools/detect_all_objects.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""
|
||||
이미지에서 객체를 SAM3.1로 검출하여 색상별로 시각화.
|
||||
|
||||
전략:
|
||||
- cols×rows 타일로 분할 (overlap 중복)
|
||||
- --tiles 로 처리할 타일 번호 지정 (예: 9-24, 1,5,9, 전체=all)
|
||||
- --categories 로 JSON 설정 파일 로드 (카테고리·프롬프트·색상 정의)
|
||||
- 타일당 SAM3.1 1회 호출 (모든 카테고리 프롬프트 합산)
|
||||
- 병렬 처리(ThreadPoolExecutor) → NMS → 색상 시각화
|
||||
|
||||
사용법:
|
||||
python tools/detect_all_objects.py \\
|
||||
--input data/역사이미지/slope/DJI_20260306113839_0005.JPG \\
|
||||
--categories configs/railway_zone.json \\
|
||||
--tiles 9-24 \\
|
||||
--cols 8 --rows 6 --overlap 0.10 --workers 4
|
||||
"""
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
SAM3_MODEL_ID = "segment_anything_3"
|
||||
|
||||
# 기본 카테고리 (--categories 미지정 시 사용)
|
||||
_DEFAULT_CATEGORIES = [
|
||||
{"name": "railway", "prompt": "railroad track, railway rail, steel rail",
|
||||
"keywords": ["railroad", "railway rail", "steel rail"], "color_bgr": [0, 200, 255]},
|
||||
{"name": "catenary_pole", "prompt": "railway catenary pole, overhead line pole, catenary mast",
|
||||
"keywords": ["catenary pole", "overhead line pole", "catenary mast"], "color_bgr": [255, 130, 0]},
|
||||
{"name": "highway", "prompt": "highway road, expressway asphalt, paved road lane",
|
||||
"keywords": ["highway", "expressway", "paved road"], "color_bgr": [160, 160, 160]},
|
||||
{"name": "vehicle", "prompt": "car, truck, vehicle, automobile",
|
||||
"keywords": ["car", "truck", "vehicle", "automobile"], "color_bgr": [0, 255, 0]},
|
||||
{"name": "building", "prompt": "building, house, rooftop, structure",
|
||||
"keywords": ["building", "house", "rooftop", "structure"], "color_bgr": [50, 50, 255]},
|
||||
{"name": "farmland", "prompt": "farmland, agricultural field, cropland, vegetable garden",
|
||||
"keywords": ["farmland", "field", "cropland", "vegetable"], "color_bgr": [50, 200, 50]},
|
||||
{"name": "vegetation", "prompt": "trees, forest, shrubs, vegetation, bushes",
|
||||
"keywords": ["tree", "forest", "shrub", "vegetation", "bush"], "color_bgr": [0, 120, 0]},
|
||||
{"name": "guardrail", "prompt": "guardrail, highway barrier, road fence, crash barrier",
|
||||
"keywords": ["guardrail", "highway barrier", "road fence", "crash barrier"], "color_bgr": [200, 0, 200]},
|
||||
{"name": "bridge", "prompt": "bridge, overpass, viaduct",
|
||||
"keywords": ["bridge", "overpass", "viaduct"], "color_bgr": [0, 165, 255]},
|
||||
{"name": "wire", "prompt": "overhead wire, catenary wire, electric cable line",
|
||||
"keywords": ["catenary wire", "overhead wire", "electric cable"], "color_bgr": [200, 200, 255]},
|
||||
]
|
||||
|
||||
|
||||
# ── 타일 번호 파싱 ────────────────────────────────────────────────────────────
|
||||
def parse_tiles(tile_str: str, total: int) -> set:
|
||||
"""'9-24', '1,3,5', 'all' → tile index 집합 (1-based)."""
|
||||
if tile_str.lower() == "all":
|
||||
return set(range(1, total + 1))
|
||||
result = set()
|
||||
for part in tile_str.split(","):
|
||||
part = part.strip()
|
||||
if "-" in part:
|
||||
a, b = part.split("-", 1)
|
||||
result.update(range(int(a), int(b) + 1))
|
||||
else:
|
||||
result.add(int(part))
|
||||
return result
|
||||
|
||||
|
||||
# ── 카테고리 로드 ─────────────────────────────────────────────────────────────
|
||||
def load_categories(json_path: str | None) -> list:
|
||||
if json_path:
|
||||
data = json.loads(Path(json_path).read_text(encoding="utf-8"))
|
||||
return data["categories"]
|
||||
return _DEFAULT_CATEGORIES
|
||||
|
||||
|
||||
def label_to_category(label: str, categories: list) -> int:
|
||||
label_l = label.lower()
|
||||
for i, cat in enumerate(categories):
|
||||
for kw in cat["keywords"]:
|
||||
if kw in label_l:
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
def build_combined_prompt(categories: list) -> str:
|
||||
return ", ".join(cat["prompt"] for cat in categories)
|
||||
|
||||
|
||||
# ── SAM3 호출 ─────────────────────────────────────────────────────────────────
|
||||
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
|
||||
h, w = image_bgr.shape[:2]
|
||||
scale = 1.0
|
||||
if max(h, w) > max_size:
|
||||
scale = max_size / max(h, w)
|
||||
image_bgr = cv2.resize(image_bgr, (int(w * scale), int(h * scale)))
|
||||
_, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
||||
return base64.b64encode(buf).decode("utf-8"), scale
|
||||
|
||||
|
||||
def sam3_segment_tile(tile_bgr: np.ndarray, prompt: str, conf: float) -> list:
|
||||
b64, scale = encode_image(tile_bgr)
|
||||
payload = {
|
||||
"model": SAM3_MODEL_ID,
|
||||
"image": b64,
|
||||
"params": {"text_prompt": prompt, "conf_threshold": conf,
|
||||
"show_masks": True, "show_boxes": False},
|
||||
}
|
||||
try:
|
||||
r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
if not resp.get("success"):
|
||||
return []
|
||||
shapes = resp.get("data", {}).get("shapes", [])
|
||||
shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
|
||||
if scale < 1.0:
|
||||
inv = 1.0 / scale
|
||||
for s in shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
|
||||
return [s for s in shapes if s.get("shape_type") == "polygon"]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
# ── NMS ───────────────────────────────────────────────────────────────────────
|
||||
def _bbox(pts):
|
||||
xs = [p[0] for p in pts]; ys = [p[1] for p in pts]
|
||||
return min(xs), min(ys), max(xs), max(ys)
|
||||
|
||||
|
||||
def _nms_core(shapes: list, iou_thresh: float) -> list:
|
||||
"""IoU 기반 NMS. shapes 각 항목에 score 필드 필요."""
|
||||
if not shapes:
|
||||
return []
|
||||
bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
|
||||
scores = np.array([float(s.get("score", 0.5)) for s in shapes])
|
||||
order = scores.argsort()[::-1]
|
||||
keep = []
|
||||
while len(order):
|
||||
i = order[0]; keep.append(i)
|
||||
if len(order) == 1: break
|
||||
xx1 = np.maximum(bboxes[i,0], bboxes[order[1:],0])
|
||||
yy1 = np.maximum(bboxes[i,1], bboxes[order[1:],1])
|
||||
xx2 = np.minimum(bboxes[i,2], bboxes[order[1:],2])
|
||||
yy2 = np.minimum(bboxes[i,3], bboxes[order[1:],3])
|
||||
inter = np.maximum(0, xx2-xx1) * np.maximum(0, yy2-yy1)
|
||||
a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1])
|
||||
a_j = (bboxes[order[1:],2]-bboxes[order[1:],0])*(bboxes[order[1:],3]-bboxes[order[1:],1])
|
||||
iou = inter / (a_i + a_j - inter + 1e-6)
|
||||
order = order[1:][iou < iou_thresh]
|
||||
return [shapes[i] for i in keep]
|
||||
|
||||
|
||||
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
|
||||
return _nms_core(shapes, iou_thresh)
|
||||
|
||||
|
||||
def _poly_orient(points: list, H: int, W: int) -> str: # post_merge_poles.py에서도 사용
|
||||
"""폴리곤 장축 방향 판별 (render_skeleton_overlay.py 동일 로직).
|
||||
|
||||
V: 장축이 이미지 중심에서 방사형 방향과 정렬 (cos_sim > 0.7) → 세로 기둥
|
||||
H: 장축이 radial 직교 방향 → 수평 빔
|
||||
?: aspect ratio < 1.3 으로 판별 불가
|
||||
"""
|
||||
pts = np.array(points, dtype=np.float32)
|
||||
rect = cv2.minAreaRect(pts)
|
||||
(rx, ry), (rw, rh), angle = rect
|
||||
if min(rw, rh) < 1:
|
||||
return '?'
|
||||
ar = max(rw, rh) / min(rw, rh)
|
||||
if ar < 1.3:
|
||||
return '?'
|
||||
long_angle_deg = angle if rw >= rh else angle + 90
|
||||
lx = float(np.cos(np.radians(long_angle_deg)))
|
||||
ly = float(np.sin(np.radians(long_angle_deg)))
|
||||
img_cx, img_cy = W / 2.0, H / 2.0
|
||||
rdx, rdy = rx - img_cx, ry - img_cy
|
||||
radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
|
||||
if radial_norm < 1:
|
||||
return '?'
|
||||
rdx, rdy = rdx / radial_norm, rdy / radial_norm
|
||||
cos_sim = abs(lx * rdx + ly * rdy)
|
||||
return 'V' if cos_sim > 0.7 else 'H'
|
||||
|
||||
|
||||
def merge_nonramen_poles(shapes: list, H: int, W: int,
|
||||
x_overlap_thresh: float = 0.30,
|
||||
y_gap_thresh: int = 150) -> list:
|
||||
"""타일 경계 분할된 전철주 병합 — V+V 조합만 허용.
|
||||
|
||||
_poly_orient로 각 폴리곤 V/H 분류.
|
||||
두 폴리곤 모두 V(세로 기둥)이고 공간 기준 충족 시만 병합.
|
||||
H(수평 빔) 포함 쌍 = 라멘 관련 조각 → 병합 건너뜀.
|
||||
"""
|
||||
if len(shapes) <= 1:
|
||||
return shapes
|
||||
|
||||
orients = [_poly_orient(s["points"], H, W) for s in shapes]
|
||||
v_count = sum(1 for o in orients if o == 'V')
|
||||
h_count = sum(1 for o in orients if o == 'H')
|
||||
print(f" [orient] V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}")
|
||||
|
||||
def get_bbox(s):
|
||||
xs = [p[0] for p in s["points"]]; ys = [p[1] for p in s["points"]]
|
||||
return min(xs), min(ys), max(xs), max(ys)
|
||||
|
||||
def x_overlap_ratio(b1, b2):
|
||||
ox = min(b1[2], b2[2]) - max(b1[0], b2[0])
|
||||
ux = max(b1[2], b2[2]) - min(b1[0], b2[0])
|
||||
return ox / ux if ux > 0 else 0.0
|
||||
|
||||
def y_gap(b1, b2):
|
||||
return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))
|
||||
|
||||
def merge_two(s1, s2):
|
||||
mask = np.zeros((H, W), dtype=np.uint8)
|
||||
for s in (s1, s2):
|
||||
cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return s1
|
||||
c = max(contours, key=cv2.contourArea)
|
||||
eps = 0.002 * cv2.arcLength(c, True)
|
||||
approx = cv2.approxPolyDP(c, eps, True)
|
||||
merged = dict(s1)
|
||||
merged["points"] = [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
merged["score"] = max(float(s1.get("score", 0)), float(s2.get("score", 0)))
|
||||
return merged
|
||||
|
||||
merged_flags = [False] * len(shapes)
|
||||
result = []
|
||||
merged_count = 0
|
||||
for i in range(len(shapes)):
|
||||
if merged_flags[i]:
|
||||
continue
|
||||
cur = shapes[i]
|
||||
cur_ori = orients[i]
|
||||
cb = get_bbox(cur)
|
||||
for j in range(i + 1, len(shapes)):
|
||||
if merged_flags[j]:
|
||||
continue
|
||||
if cur_ori != 'V' or orients[j] != 'V':
|
||||
continue # 둘 다 V가 아니면 병합 안 함
|
||||
jb = get_bbox(shapes[j])
|
||||
if x_overlap_ratio(cb, jb) >= x_overlap_thresh and y_gap(cb, jb) <= y_gap_thresh:
|
||||
cur = merge_two(cur, shapes[j])
|
||||
cur_ori = 'V'
|
||||
cb = get_bbox(cur)
|
||||
merged_flags[j] = True
|
||||
merged_count += 1
|
||||
result.append(cur)
|
||||
print(f" [merge] 병합={merged_count}쌍")
|
||||
return result
|
||||
|
||||
|
||||
def cross_class_nms(buckets: list, categories: list, iou_thresh: float) -> list:
|
||||
"""클래스 간 NMS: 동일 영역에 다른 클래스가 중복 검출될 때 우선순위 높은 쪽 보존.
|
||||
|
||||
정렬 기준: (priority 오름차순, score 내림차순)
|
||||
→ priority 낮은 값(=중요 클래스)이 우선 보존됨.
|
||||
"""
|
||||
# 모든 shape에 클래스 인덱스 태깅
|
||||
tagged = []
|
||||
for i, shapes in enumerate(buckets):
|
||||
priority = categories[i].get("priority", 99)
|
||||
for s in shapes:
|
||||
tagged.append((priority, -float(s.get("score", 0.5)), i, s))
|
||||
|
||||
# priority 오름차순, score 내림차순 정렬
|
||||
tagged.sort(key=lambda x: (x[0], x[1]))
|
||||
|
||||
if not tagged:
|
||||
return [[] for _ in buckets]
|
||||
|
||||
all_shapes = [t[3] for t in tagged]
|
||||
cls_ids = [t[2] for t in tagged]
|
||||
|
||||
bboxes = np.array([_bbox(s["points"]) for s in all_shapes], dtype=np.float32)
|
||||
suppressed = [False] * len(all_shapes)
|
||||
|
||||
for i in range(len(all_shapes)):
|
||||
if suppressed[i]:
|
||||
continue
|
||||
for j in range(i + 1, len(all_shapes)):
|
||||
if suppressed[j]:
|
||||
continue
|
||||
xx1 = max(bboxes[i,0], bboxes[j,0])
|
||||
yy1 = max(bboxes[i,1], bboxes[j,1])
|
||||
xx2 = min(bboxes[i,2], bboxes[j,2])
|
||||
yy2 = min(bboxes[i,3], bboxes[j,3])
|
||||
inter = max(0, xx2-xx1) * max(0, yy2-yy1)
|
||||
if inter == 0:
|
||||
continue
|
||||
a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1])
|
||||
a_j = (bboxes[j,2]-bboxes[j,0])*(bboxes[j,3]-bboxes[j,1])
|
||||
iou = inter / (a_i + a_j - inter + 1e-6)
|
||||
if iou >= iou_thresh:
|
||||
suppressed[j] = True # i가 우선순위 높으므로 j 제거
|
||||
|
||||
new_buckets = [[] for _ in buckets]
|
||||
for i, (keep, cls_i, s) in enumerate(zip(suppressed, cls_ids, all_shapes)):
|
||||
if not keep:
|
||||
new_buckets[cls_i].append(s)
|
||||
return new_buckets
|
||||
|
||||
|
||||
# ── 타일 그리드 검출 (병렬) ───────────────────────────────────────────────────
|
||||
def detect_tiled(image_bgr: np.ndarray, cols: int, rows: int, overlap: float,
|
||||
conf: float, workers: int, tile_filter: set,
|
||||
prompt: str) -> list:
|
||||
H, W = image_bgr.shape[:2]
|
||||
base_w = W / cols
|
||||
base_h = H / rows
|
||||
pad_x = int(base_w * overlap)
|
||||
pad_y = int(base_h * overlap)
|
||||
|
||||
tiles = []
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
idx = r * cols + c + 1
|
||||
if idx not in tile_filter:
|
||||
continue
|
||||
x0 = max(0, int(c * base_w) - pad_x)
|
||||
x1 = min(W, int((c + 1) * base_w) + pad_x)
|
||||
y0 = max(0, int(r * base_h) - pad_y)
|
||||
y1 = min(H, int((r + 1) * base_h) + pad_y)
|
||||
tiles.append((idx, x0, y0, x1, y1))
|
||||
|
||||
total = len(tiles)
|
||||
done = [0]
|
||||
all_shapes = []
|
||||
|
||||
def process(args):
|
||||
idx, x0, y0, x1, y1 = args
|
||||
tile = image_bgr[y0:y1, x0:x1]
|
||||
shapes = sam3_segment_tile(tile, prompt, conf)
|
||||
for s in shapes:
|
||||
s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
|
||||
return shapes
|
||||
|
||||
with ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
futs = {ex.submit(process, t): t for t in tiles}
|
||||
for fut in as_completed(futs):
|
||||
all_shapes.extend(fut.result())
|
||||
done[0] += 1
|
||||
print(f" 타일 {done[0]:02d}/{total} 완료, 누적 {len(all_shapes)}개", end="\r")
|
||||
print()
|
||||
return all_shapes
|
||||
|
||||
|
||||
# ── 시각화 ────────────────────────────────────────────────────────────────────
|
||||
def draw_detections(image_bgr: np.ndarray, buckets: list,
|
||||
categories: list, tile_filter: set,
|
||||
cols: int, rows: int, overlap: float) -> np.ndarray:
|
||||
vis = image_bgr.copy()
|
||||
H, W = vis.shape[:2]
|
||||
|
||||
# 처리된 타일 경계 표시
|
||||
base_w = W / cols
|
||||
base_h = H / rows
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
idx = r * cols + c + 1
|
||||
if idx in tile_filter:
|
||||
bx0, by0 = int(c * base_w), int(r * base_h)
|
||||
bx1, by1 = min(W, int((c+1)*base_w)), min(H, int((r+1)*base_h))
|
||||
cv2.rectangle(vis, (bx0, by0), (bx1, by1), (255, 255, 255), 1)
|
||||
|
||||
# 마스크 + 순번 레이블
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_sc = max(0.4, min(W, H) / 8000)
|
||||
thickness = max(1, int(font_sc * 2))
|
||||
|
||||
for i, cat in enumerate(categories):
|
||||
color = tuple(cat["color_bgr"])
|
||||
prefix = cat["name"][:3].upper() # 예: RAI, CAT, BRA …
|
||||
for seq, s in enumerate(buckets[i], start=1):
|
||||
pts = np.array(s["points"], dtype=np.int32)
|
||||
overlay = vis.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.35, vis, 0.65, 0, vis)
|
||||
cv2.polylines(vis, [pts], True, color, 2)
|
||||
|
||||
# 무게중심에 순번 표시
|
||||
cx = int(np.mean(pts[:, 0]))
|
||||
cy = int(np.mean(pts[:, 1]))
|
||||
score = float(s.get("score", 0.0))
|
||||
label = f"{prefix}{seq:03d} {score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(label, font, font_sc, thickness)
|
||||
tx, ty = cx - tw // 2, cy + th // 2
|
||||
# 배경 박스는 검정, 텍스트는 흰색으로 변경
|
||||
cv2.rectangle(vis, (tx - 2, ty - th - 2), (tx + tw + 2, ty + 2),
|
||||
(0, 0, 0), -1)
|
||||
cv2.putText(vis, label, (tx, ty), font, font_sc,
|
||||
(255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
# 범례
|
||||
lx, ly = W - 280, 20
|
||||
panel_h = len(categories) * 24 + 10
|
||||
vis[0:panel_h, lx-8:W] = (vis[0:panel_h, lx-8:W] * 0.35).astype(np.uint8)
|
||||
for i, cat in enumerate(categories):
|
||||
color = tuple(cat["color_bgr"])
|
||||
prefix = cat["name"][:3].upper()
|
||||
cv2.rectangle(vis, (lx, ly-13), (lx+15, ly+3), color, -1)
|
||||
cv2.putText(vis, f"[{prefix}] {cat['name']} ({len(buckets[i])})",
|
||||
(lx+20, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.50, color, 1, cv2.LINE_AA)
|
||||
ly += 24
|
||||
|
||||
return vis
|
||||
|
||||
|
||||
# ── 메인 ──────────────────────────────────────────────────────────────────────
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description=__doc__)
|
||||
ap.add_argument("--input", required=True, help="입력 이미지 경로")
|
||||
ap.add_argument("--output", default=None, help="출력 이미지 경로 (기본: 입력파일명_out.jpg)")
|
||||
ap.add_argument("--categories", default=None, help="카테고리 JSON 경로 (기본: 내장 10개)")
|
||||
ap.add_argument("--tiles", default="all", help="처리할 타일 번호: 9-24 / 1,5,9 / all (기본: all)")
|
||||
ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본: 8)")
|
||||
ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본: 6)")
|
||||
ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본: 0.10)")
|
||||
ap.add_argument("--conf", type=float, default=0.20, help="SAM3 신뢰도 임계값 (기본: 0.20)")
|
||||
ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본: 4)")
|
||||
ap.add_argument("--save-labels", action="store_true", help="YOLO 폴리곤 포맷 .txt 라벨 파일 저장")
|
||||
ap.add_argument("--save-json", action="store_true", help="AnyLabeling JSON 포맷 저장 (railway.json 동일 양식)")
|
||||
args = ap.parse_args()
|
||||
|
||||
img_path = Path(args.input)
|
||||
buf = np.fromfile(str(img_path), dtype=np.uint8)
|
||||
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image_bgr is None:
|
||||
print(f"이미지 로드 실패: {img_path}"); return
|
||||
|
||||
H, W = image_bgr.shape[:2]
|
||||
total_tiles = args.cols * args.rows
|
||||
tile_filter = parse_tiles(args.tiles, total_tiles)
|
||||
categories = load_categories(args.categories)
|
||||
combined_prompt = build_combined_prompt(categories)
|
||||
|
||||
# cross-class NMS IoU: JSON > 기본값 0.45
|
||||
cc_iou = 0.45
|
||||
if args.categories:
|
||||
raw = json.loads(Path(args.categories).read_text(encoding="utf-8"))
|
||||
cc_iou = raw.get("cross_class_nms_iou", cc_iou)
|
||||
|
||||
# SAM3 호출용 conf = 모든 카테고리 conf 중 최솟값 (낮은 쪽부터 받아 후처리로 필터)
|
||||
sam3_conf = min(cat.get("conf", args.conf) for cat in categories)
|
||||
|
||||
print(f"이미지 : {W}×{H}")
|
||||
print(f"타일 그리드: {args.cols}×{args.rows}={total_tiles}개 | 처리 대상: {sorted(tile_filter)}")
|
||||
print(f"카테고리 : {len(categories)}개 | 중복: {args.overlap*100:.0f}%")
|
||||
print(f"SAM3 conf : {sam3_conf} (전체 최솟값) | cross-class NMS IoU: {cc_iou}")
|
||||
print(f"SAM3 호출 : {len(tile_filter)}회 | 병렬: {args.workers}스레드\n")
|
||||
|
||||
t0 = time.time()
|
||||
all_shapes = detect_tiled(image_bgr, args.cols, args.rows, args.overlap,
|
||||
sam3_conf, args.workers, tile_filter, combined_prompt)
|
||||
|
||||
print(f"전체 검출 {len(all_shapes)}개 → 분류 + per-class conf 필터 + NMS...")
|
||||
buckets = [[] for _ in categories]
|
||||
unmatched = 0
|
||||
for s in all_shapes:
|
||||
idx = label_to_category(s.get("label", ""), categories)
|
||||
if idx < 0:
|
||||
unmatched += 1
|
||||
continue
|
||||
# per-class conf 필터
|
||||
cat_conf = categories[idx].get("conf", args.conf)
|
||||
if float(s.get("score", 0.0)) < cat_conf:
|
||||
continue
|
||||
buckets[idx].append(s)
|
||||
|
||||
# 클래스 내 NMS
|
||||
print(" [1] 클래스 내 NMS")
|
||||
for i, cat in enumerate(categories):
|
||||
before = len(buckets[i])
|
||||
buckets[i] = nms_shapes(buckets[i])
|
||||
print(f" {cat['name']:18s}: {before:3d} → {len(buckets[i]):3d}개 (conf≥{cat.get('conf', args.conf)})")
|
||||
|
||||
# 클래스 간 NMS
|
||||
total_before = sum(len(b) for b in buckets)
|
||||
print(f" [2] cross-class NMS (IoU≥{cc_iou})")
|
||||
buckets = cross_class_nms(buckets, categories, cc_iou)
|
||||
total_after = sum(len(b) for b in buckets)
|
||||
print(f" {total_before}개 → {total_after}개")
|
||||
|
||||
if unmatched:
|
||||
print(f" (미분류/conf미달 {unmatched}개 제외)")
|
||||
print(f"\n완료: {time.time()-t0:.0f}초")
|
||||
|
||||
vis = draw_detections(image_bgr, buckets, categories,
|
||||
tile_filter, args.cols, args.rows, args.overlap)
|
||||
h, w = vis.shape[:2]
|
||||
if max(h, w) > 4096:
|
||||
s = 4096 / max(h, w)
|
||||
vis = cv2.resize(vis, (int(w*s), int(h*s)))
|
||||
|
||||
if args.output:
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
tile_tag = args.tiles.replace(",", "_").replace("-", "to")
|
||||
cat_tag = Path(args.categories).stem if args.categories else "default"
|
||||
out_dir = Path("output") / "detect" / img_path.stem
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
base_name = f"tiles{tile_tag}_{cat_tag}"
|
||||
n = 1
|
||||
while True:
|
||||
out_path = out_dir / f"{base_name}_{n:03d}.jpg"
|
||||
if not out_path.exists():
|
||||
break
|
||||
n += 1
|
||||
|
||||
cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path))
|
||||
print(f"저장: {out_path}")
|
||||
|
||||
if args.save_labels:
|
||||
label_path = out_path.with_suffix(".txt")
|
||||
with open(label_path, "w", encoding="utf-8") as f:
|
||||
for cls_idx, shapes in enumerate(buckets):
|
||||
for s in shapes:
|
||||
pts_norm = [[px / W, py / H] for px, py in s["points"]]
|
||||
coords = " ".join(f"{x:.6f} {y:.6f}" for x, y in pts_norm)
|
||||
f.write(f"{cls_idx} {coords}\n")
|
||||
print(f"라벨 저장: {label_path}")
|
||||
|
||||
if args.save_json:
|
||||
import json as _json
|
||||
json_shapes = []
|
||||
for cls_idx, shapes in enumerate(buckets):
|
||||
cat_name = categories[cls_idx]["name"] if cls_idx < len(categories) else str(cls_idx)
|
||||
for s in shapes:
|
||||
json_shapes.append({
|
||||
"label": cat_name,
|
||||
"score": float(s.get("score", 0.0)),
|
||||
"points": [[float(px), float(py)] for px, py in s["points"]],
|
||||
"group_id": None,
|
||||
"description": None,
|
||||
"shape_type": "polygon",
|
||||
"flags": None,
|
||||
})
|
||||
anylabel = {
|
||||
"version": "3.3.9",
|
||||
"flags": {},
|
||||
"shapes": json_shapes,
|
||||
"imagePath": img_path.name,
|
||||
"imageData": None,
|
||||
"imageHeight": H,
|
||||
"imageWidth": W,
|
||||
}
|
||||
json_path = out_path.with_suffix(".json")
|
||||
json_path.write_text(_json.dumps(anylabel, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8")
|
||||
print(f"JSON 저장: {json_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
152
tools/detect_hollow_section.py
Normal file
152
tools/detect_hollow_section.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
얇은 중공단면(도너츠 형태) 검출 및 표시 스크립트
|
||||
드론 영상에서 전철주 상단 원형 단면을 찾아 표시
|
||||
|
||||
사용법:
|
||||
python detect_hollow_section.py <image> [--radius <px>] [--tol <px>] [--topk <n>]
|
||||
|
||||
--radius : 3D 프로그램 줌 기준 단면 반지름 (픽셀). 미지정 시 자동 탐색.
|
||||
--tol : radius ± 허용 오차 (기본 30px)
|
||||
--topk : 표시할 최대 후보 수 (기본 3)
|
||||
"""
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def ring_score(edges: np.ndarray, cx: int, cy: int, r: int,
|
||||
H: int, W: int, inner_ratio: float = 0.60) -> float:
|
||||
"""
|
||||
링(도넛) 특성 점수.
|
||||
- 링 영역(외곽~내곽) 엣지 강도 / 내부 엣지 강도 비율 (최대 5 클리핑)
|
||||
- 이미지 중심 근접도 보정 (중심에 가까울수록 가산점)
|
||||
Canny edges는 외부에서 한 번만 계산해 전달받음.
|
||||
"""
|
||||
inner_r = max(int(r * inner_ratio), 3)
|
||||
|
||||
mask_full = np.zeros((H, W), np.uint8)
|
||||
mask_inner = np.zeros((H, W), np.uint8)
|
||||
cv2.circle(mask_full, (cx, cy), r, 255, -1)
|
||||
cv2.circle(mask_inner, (cx, cy), inner_r, 255, -1)
|
||||
mask_ring = cv2.subtract(mask_full, mask_inner)
|
||||
|
||||
if mask_ring.any() == 0 or mask_inner.any() == 0:
|
||||
return 0.0
|
||||
|
||||
ring_edge = float(edges[mask_ring > 0].mean())
|
||||
inner_edge = float(edges[mask_inner > 0].mean()) + 1e-6
|
||||
edge_ratio = min(ring_edge / inner_edge, 5.0)
|
||||
|
||||
dist_from_center = np.hypot(cx - W / 2, cy - H / 2)
|
||||
max_dist = np.hypot(W / 2, H / 2)
|
||||
center_bonus = 1.0 - dist_from_center / max_dist # 0~1
|
||||
|
||||
return edge_ratio * (0.7 + 0.3 * center_bonus)
|
||||
|
||||
|
||||
def detect_hollow_section(
|
||||
image_path: str,
|
||||
output_path: str | None = None,
|
||||
radius_px: int | None = None, # 3D 프로그램 줌 기준 알려진 반지름
|
||||
tol_px: int = 30, # radius_px ± 허용 오차
|
||||
top_k: int = 3,
|
||||
) -> str:
|
||||
img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise FileNotFoundError(f"이미지 로드 실패: {image_path}")
|
||||
|
||||
H, W = img.shape[:2]
|
||||
print(f"이미지 크기: {W}×{H}")
|
||||
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
|
||||
blurred = cv2.GaussianBlur(clahe.apply(gray), (7, 7), 2)
|
||||
|
||||
# Canny 엣지를 미리 한 번만 계산 (ring_score에서 재사용)
|
||||
edges = cv2.Canny(blurred, 30, 90)
|
||||
|
||||
# 반경 범위 결정
|
||||
if radius_px is not None:
|
||||
min_r = max(5, radius_px - tol_px)
|
||||
max_r = radius_px + tol_px
|
||||
else:
|
||||
short = min(W, H)
|
||||
min_r = max(10, short // 20)
|
||||
max_r = short // 3
|
||||
|
||||
print(f"탐색 반경: {min_r} ~ {max_r} px"
|
||||
+ (f" (기준={radius_px}px ±{tol_px})" if radius_px else " (자동)"))
|
||||
|
||||
circles = cv2.HoughCircles(
|
||||
blurred,
|
||||
cv2.HOUGH_GRADIENT,
|
||||
dp=1.0,
|
||||
minDist=min_r * 3,
|
||||
param1=80,
|
||||
param2=35,
|
||||
minRadius=min_r,
|
||||
maxRadius=max_r,
|
||||
)
|
||||
|
||||
result = img.copy()
|
||||
|
||||
if circles is None:
|
||||
print("HoughCircles 미검출")
|
||||
else:
|
||||
circles = np.round(circles[0]).astype(int)
|
||||
print(f"HoughCircles 후보: {len(circles)}개 → 링 스코어 계산 중...")
|
||||
|
||||
scored = sorted(
|
||||
[(ring_score(edges, cx, cy, r, H, W), cx, cy, r)
|
||||
for cx, cy, r in circles],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# 1위: 녹색(굵게), 2위: 노랑, 3위: 파랑
|
||||
colors = [(0, 255, 0), (0, 220, 255), (255, 120, 0)]
|
||||
|
||||
print(f"\n상위 {min(top_k, len(scored))}개 중공단면 후보:")
|
||||
for rank, (score, cx, cy, r) in enumerate(scored[:top_k]):
|
||||
color = colors[rank % len(colors)]
|
||||
inner_r = max(int(r * 0.60), 5)
|
||||
lw = 3 if rank == 0 else 2
|
||||
|
||||
cv2.circle(result, (cx, cy), r, color, lw)
|
||||
cv2.circle(result, (cx, cy), inner_r, color, 1)
|
||||
cv2.circle(result, (cx, cy), 6, (0, 0, 255), -1)
|
||||
arm = r + 15
|
||||
cv2.line(result, (cx - arm, cy), (cx + arm, cy), color, 1)
|
||||
cv2.line(result, (cx, cy - arm), (cx, cy + arm), color, 1)
|
||||
cv2.putText(result, f"#{rank+1} r={r}px s={score:.2f}",
|
||||
(cx - r, cy - r - 8),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 2)
|
||||
print(f" #{rank+1}: 중심=({cx},{cy}), 반지름={r}px, 링스코어={score:.3f}")
|
||||
|
||||
if output_path is None:
|
||||
p = Path(image_path)
|
||||
output_path = str(p.parent / (p.stem + "_hollow_detected" + p.suffix))
|
||||
|
||||
cv2.imencode(Path(output_path).suffix, result)[1].tofile(output_path)
|
||||
print(f"\n결과 저장: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="얇은 중공단면 검출")
|
||||
parser.add_argument("image", nargs="?",
|
||||
default=r"d:\MYCLAUDE_PROJECT\x-anylabeling01\data\Message_2026-04-23T11_13_29+09_00.png")
|
||||
parser.add_argument("--radius", type=int, default=None,
|
||||
help="3D 프로그램 줌 기준 단면 반지름(px). 미지정 시 자동 탐색.")
|
||||
parser.add_argument("--tol", type=int, default=30,
|
||||
help="radius ± 허용 오차 (기본 30px)")
|
||||
parser.add_argument("--topk", type=int, default=3,
|
||||
help="표시할 최대 후보 수 (기본 3)")
|
||||
args = parser.parse_args()
|
||||
|
||||
detect_hollow_section(
|
||||
args.image,
|
||||
radius_px=args.radius,
|
||||
tol_px=args.tol,
|
||||
top_k=args.topk,
|
||||
)
|
||||
666
tools/detect_raamen.py
Normal file
666
tools/detect_raamen.py
Normal file
@@ -0,0 +1,666 @@
|
||||
"""
|
||||
드론 사선 촬영 이미지에서 라멘형(門자형) 전철주 검출.
|
||||
기존 픽셀 위상수학(Skeleton) 방식을 공간 기하학 방식으로 교체.
|
||||
|
||||
파이프라인:
|
||||
Phase 1: 폴리곤 단순화(approxPolyDP) + 소실점(Vanishing Point) 계산
|
||||
Phase 2: 동적 V/H 분류 (소실점 기반 기대 각도)
|
||||
Phase 3: 근접성 기반 그룹핑 (H 앵커 → 아래 V 탐색)
|
||||
Phase 4: 라멘 구조 판정 + 예외(가림) 처리
|
||||
|
||||
사용:
|
||||
python tools/detect_raamen.py \
|
||||
--image <path> --label <path> --output <path> \
|
||||
[--class-ids 1] [--epsilon 4.0] [--v-thresh 20.0]
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ── Phase 1: 파싱 + 단순화 + 소실점 ─────────────────────────────────────
|
||||
|
||||
def load_polygons(label_path, W, H, class_ids=None, class_names=None):
|
||||
"""
|
||||
AnyLabeling JSON 또는 YOLO .txt 자동 감지 후 파싱.
|
||||
Returns: (픽셀 좌표 폴리곤 리스트, 절대 shapes[] 인덱스 리스트)
|
||||
절대 인덱스는 AnyLabeling JSON shapes[] 배열 인덱스와 동일.
|
||||
"""
|
||||
import json as _json
|
||||
if not label_path.exists():
|
||||
raise FileNotFoundError(f"라벨 파일을 찾을 수 없습니다: {label_path}")
|
||||
|
||||
if label_path.suffix.lower() == ".json":
|
||||
data = _json.loads(label_path.read_text(encoding="utf-8"))
|
||||
shapes = data.get("shapes", [])
|
||||
polys, abs_indices = [], []
|
||||
for abs_idx, s in enumerate(shapes):
|
||||
if class_names is not None and s.get("label", "") not in class_names:
|
||||
continue
|
||||
pts = np.array([[float(p[0]), float(p[1])] for p in s.get("points", [])],
|
||||
dtype=np.float32)
|
||||
if len(pts) >= 3:
|
||||
polys.append(pts)
|
||||
abs_indices.append(abs_idx)
|
||||
return polys, abs_indices
|
||||
|
||||
# YOLO .txt
|
||||
polys, abs_indices = [], []
|
||||
for abs_idx, line in enumerate(label_path.read_text(encoding="utf-8").splitlines()):
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
continue
|
||||
if class_ids is not None and int(parts[0]) not in class_ids:
|
||||
continue
|
||||
coords = list(map(float, parts[1:]))
|
||||
pts = np.array(
|
||||
[[coords[i] * W, coords[i + 1] * H] for i in range(0, len(coords), 2)],
|
||||
dtype=np.float32,
|
||||
)
|
||||
if len(pts) >= 3:
|
||||
polys.append(pts)
|
||||
abs_indices.append(abs_idx)
|
||||
return polys, abs_indices
|
||||
|
||||
|
||||
def smooth_polygon(pts, epsilon):
|
||||
"""approxPolyDP로 노이즈 경계 → 직선 위주 단순화."""
|
||||
approx = cv2.approxPolyDP(pts.astype(np.int32), epsilon, closed=True)
|
||||
result = approx.reshape(-1, 2).astype(np.float32)
|
||||
return result if len(result) >= 3 else pts.astype(np.float32)
|
||||
|
||||
|
||||
def _minrect(pts):
|
||||
"""cv2.minAreaRect 래퍼 (int32 변환 포함)."""
|
||||
return cv2.minAreaRect(pts.astype(np.int32))
|
||||
|
||||
|
||||
def _long_axis_angle(pts):
|
||||
"""minAreaRect 장축 각도 (degrees) 반환."""
|
||||
(cx, cy), (rw, rh), angle = _minrect(pts)
|
||||
return angle if rw >= rh else angle + 90, cx, cy, max(rw, rh), min(rw, rh)
|
||||
|
||||
|
||||
def _radial_cos_sim(pts, img_cx, img_cy):
|
||||
"""VP 후보 bootstrap용 기존 radial cos_sim."""
|
||||
long_angle_deg, cx, cy, long_side, short_side = _long_axis_angle(pts)
|
||||
if short_side < 1 or long_side / short_side < 1.3:
|
||||
return 0.0
|
||||
lx = np.cos(np.radians(long_angle_deg))
|
||||
ly = np.sin(np.radians(long_angle_deg))
|
||||
rdx, rdy = cx - img_cx, cy - img_cy
|
||||
n = (rdx ** 2 + rdy ** 2) ** 0.5
|
||||
return abs(lx * rdx / n + ly * rdy / n) if n > 1 else 0.0
|
||||
|
||||
|
||||
def compute_vanishing_point(polys):
|
||||
"""
|
||||
후보 폴리곤 장축 선분들의 최소자승 교점 (소실점) 계산.
|
||||
각 장축 선분: 법선 n=(-dy, dx), 방정식 -dy*x + dx*y = -dy*cx + dx*cy
|
||||
"""
|
||||
A_rows, b_vals = [], []
|
||||
for pts in polys:
|
||||
long_angle_deg, cx, cy, _, _ = _long_axis_angle(pts)
|
||||
dx = np.cos(np.radians(long_angle_deg))
|
||||
dy = np.sin(np.radians(long_angle_deg))
|
||||
A_rows.append([-dy, dx])
|
||||
b_vals.append(-dy * cx + dx * cy)
|
||||
vp, *_ = np.linalg.lstsq(np.array(A_rows), np.array(b_vals), rcond=None)
|
||||
return float(vp[0]), float(vp[1])
|
||||
|
||||
|
||||
def _estimate_vp_iterative(polys, seed_indices, v_thresh, h_max_diff, vp_min_len,
|
||||
x_horiz_thresh=10.0, max_iter=6):
|
||||
"""초기 후보에서 반복 정제 VP 추정. Returns (vp_x, vp_y, n_v, orients, adiffs)."""
|
||||
n = len(polys)
|
||||
orients = ['?'] * n
|
||||
adiffs = [90.0] * n
|
||||
if len(seed_indices) < 2:
|
||||
return None, None, 0, orients, adiffs
|
||||
vp_x, vp_y = compute_vanishing_point([polys[i] for i in seed_indices])
|
||||
for _ in range(max_iter):
|
||||
for i, pts in enumerate(polys):
|
||||
orients[i], adiffs[i] = classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff,
|
||||
x_horiz_thresh)
|
||||
v_cands = [i for i in range(n)
|
||||
if orients[i] == 'V' and _long_axis_angle(polys[i])[3] > vp_min_len]
|
||||
if len(v_cands) < 3:
|
||||
break
|
||||
nx, ny = compute_vanishing_point([polys[i] for i in v_cands])
|
||||
shift = ((nx - vp_x) ** 2 + (ny - vp_y) ** 2) ** 0.5
|
||||
vp_x, vp_y = nx, ny
|
||||
if shift < 5.0:
|
||||
for i, pts in enumerate(polys):
|
||||
orients[i], adiffs[i] = classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff,
|
||||
x_horiz_thresh)
|
||||
break
|
||||
return vp_x, vp_y, orients.count('V'), orients, adiffs
|
||||
|
||||
|
||||
# ── Phase 2: 동적 V/H 분류 ──────────────────────────────────────────────
|
||||
|
||||
def classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff=75.0, x_horiz_thresh=10.0):
|
||||
"""
|
||||
소실점 기준 V/H 분류.
|
||||
- 이미지 절대 수평(X축 ±x_horiz_thresh°) AND AR≥4 → '?' (레일·전선 등)
|
||||
- diff < v_thresh → V (기둥)
|
||||
- v_thresh ≤ diff < h_max_diff → H (빔)
|
||||
- diff ≥ h_max_diff → '?'
|
||||
Returns: (orient, angle_diff_deg)
|
||||
"""
|
||||
long_angle_deg, cx, cy, long_side, short_side = _long_axis_angle(pts)
|
||||
if short_side < 1 or long_side / short_side < 1.3:
|
||||
return '?', 90.0
|
||||
|
||||
# 절대 수평 제외: X축 ±x_horiz_thresh° 이내 + AR≥4 (레일·전선 등 가느다란 수평체)
|
||||
abs_from_horiz = long_angle_deg % 180.0
|
||||
if abs_from_horiz > 90.0:
|
||||
abs_from_horiz = 180.0 - abs_from_horiz
|
||||
if abs_from_horiz < x_horiz_thresh and long_side / short_side >= 4.0:
|
||||
return '?', 90.0
|
||||
|
||||
# VP 기준 상대 각도 분류
|
||||
exp_angle = np.degrees(np.arctan2(vp_y - cy, vp_x - cx))
|
||||
diff = abs(long_angle_deg - exp_angle) % 180.0
|
||||
if diff > 90.0:
|
||||
diff = 180.0 - diff
|
||||
if diff < v_thresh:
|
||||
return 'V', diff
|
||||
if diff < h_max_diff:
|
||||
return 'H', diff
|
||||
return '?', diff
|
||||
|
||||
|
||||
# ── Phase 3: 폴리곤 접촉/교차 기반 그룹핑 ───────────────────────────────
|
||||
|
||||
def connectivity_groups(polys, orients, margin=30):
|
||||
"""
|
||||
폴리곤 bbox가 margin px 이내로 닿거나 교차하면 같은 그룹 (H/V 구분 없음).
|
||||
Union-Find로 연결된 폴리곤들을 묶은 뒤, 각 그룹 내에서 H/V 목록 분리.
|
||||
Returns: list of {'id': int, 'H': [idx,...], 'V': [idx,...]}
|
||||
"""
|
||||
n = len(polys)
|
||||
parent = list(range(n))
|
||||
|
||||
def find(x):
|
||||
while parent[x] != x:
|
||||
parent[x] = parent[parent[x]]
|
||||
x = parent[x]
|
||||
return x
|
||||
|
||||
def union(a, b):
|
||||
ra, rb = find(a), find(b)
|
||||
if ra != rb:
|
||||
parent[ra] = rb
|
||||
|
||||
# 폴리곤 bbox 미리 계산
|
||||
bboxes = []
|
||||
for pts in polys:
|
||||
bboxes.append((pts[:, 0].min(), pts[:, 1].min(),
|
||||
pts[:, 0].max(), pts[:, 1].max()))
|
||||
|
||||
# bbox가 margin 이내로 닿거나 교차하면 union
|
||||
for i in range(n):
|
||||
ax0, ay0, ax1, ay1 = bboxes[i]
|
||||
for j in range(i + 1, n):
|
||||
bx0, by0, bx1, by1 = bboxes[j]
|
||||
if (ax1 + margin >= bx0 and bx1 + margin >= ax0 and
|
||||
ay1 + margin >= by0 and by1 + margin >= ay0):
|
||||
union(i, j)
|
||||
|
||||
# 연결 컴포넌트 수집
|
||||
from collections import defaultdict
|
||||
comp = defaultdict(list)
|
||||
for i in range(n):
|
||||
comp[find(i)].append(i)
|
||||
|
||||
groups = []
|
||||
for gid, members in enumerate(comp.values(), 1):
|
||||
h_list = sorted(i for i in members if orients[i] == 'H')
|
||||
v_list = sorted(i for i in members if orients[i] == 'V')
|
||||
groups.append({'id': gid, 'H': h_list, 'V': v_list})
|
||||
|
||||
# 면적 내림차순으로 ID 재부여 (큰 그룹이 G1)
|
||||
for g in groups:
|
||||
g['area'] = sum(cv2.contourArea(polys[i].astype(np.int32))
|
||||
for i in g['H'] + g['V'])
|
||||
groups.sort(key=lambda x: x['area'], reverse=True)
|
||||
for gid, g in enumerate(groups, 1):
|
||||
g['id'] = gid
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
# ── Phase 4: 라멘 구조 판정 ─────────────────────────────────────────────
|
||||
|
||||
def _cluster_polys(indices, polys, margin=60):
|
||||
"""
|
||||
인접 폴리곤을 클러스터링 → 물리적 객체(기둥) 단위 반환.
|
||||
Returns: list of [idx, ...] clusters
|
||||
"""
|
||||
if not indices:
|
||||
return []
|
||||
n = len(indices)
|
||||
parent = list(range(n))
|
||||
|
||||
def find(x):
|
||||
while parent[x] != x:
|
||||
parent[x] = parent[parent[x]]
|
||||
x = parent[x]
|
||||
return x
|
||||
|
||||
bboxes = [(polys[i][:, 0].min(), polys[i][:, 1].min(),
|
||||
polys[i][:, 0].max(), polys[i][:, 1].max()) for i in indices]
|
||||
for a in range(n):
|
||||
ax0, ay0, ax1, ay1 = bboxes[a]
|
||||
for b in range(a + 1, n):
|
||||
bx0, by0, bx1, by1 = bboxes[b]
|
||||
if (ax1 + margin >= bx0 and bx1 + margin >= ax0 and
|
||||
ay1 + margin >= by0 and by1 + margin >= ay0):
|
||||
ra, rb = find(a), find(b)
|
||||
if ra != rb:
|
||||
parent[ra] = rb
|
||||
|
||||
from collections import defaultdict
|
||||
clusters = defaultdict(list)
|
||||
for i in range(n):
|
||||
clusters[find(i)].append(indices[i])
|
||||
return list(clusters.values())
|
||||
|
||||
|
||||
def judge_raamen(group, polys, W, center_ratio=0.2, v_cluster_margin=60):
|
||||
"""
|
||||
라멘 구조 판정.
|
||||
- 빔 기준: 그룹 내 가장 큰 H 폴리곤
|
||||
- 기둥 수: V 폴리곤을 근접 클러스터링한 클러스터 수
|
||||
- 빔 x 범위(±50%) 밖의 V 클러스터는 잡폴리곤으로 무시
|
||||
Returns: ('RAAMEN' | 'RAAMEN_OCCLUDED' | 'PARTIAL' | '', n_poles)
|
||||
"""
|
||||
hs, vs = group['H'], group['V']
|
||||
|
||||
# H 없음: 중앙 영역이면 V/H 분류 자체가 신뢰 불가 (기둥·빔 각도 수렴)
|
||||
# → 2개 이상의 V 폴리곤이 중앙에 모여 있으면 RAAMEN_CENTER 처리
|
||||
if not hs:
|
||||
if len(vs) >= 2:
|
||||
all_v_pts = np.vstack([polys[i] for i in vs])
|
||||
gcx = float(all_v_pts[:, 0].mean())
|
||||
if abs(gcx - W / 2) < W * center_ratio:
|
||||
return 'RAAMEN_CENTER', len(vs)
|
||||
return '', 0
|
||||
|
||||
# 큰 H 폴리곤 최대 2개를 빔 기준으로 사용 (2nd가 1st 면적의 50% 이상이면 포함)
|
||||
h_by_area = sorted(hs, key=lambda i: cv2.contourArea(polys[i].astype(np.int32)), reverse=True)
|
||||
main_hs = [h_by_area[0]]
|
||||
if len(h_by_area) > 1:
|
||||
a0 = cv2.contourArea(polys[h_by_area[0]].astype(np.int32))
|
||||
a1 = cv2.contourArea(polys[h_by_area[1]].astype(np.int32))
|
||||
if a1 >= a0 * 0.5:
|
||||
main_hs.append(h_by_area[1])
|
||||
hx0 = int(min(polys[i][:, 0].min() for i in main_hs))
|
||||
hx1 = int(max(polys[i][:, 0].max() for i in main_hs))
|
||||
hcx = (hx0 + hx1) / 2.0
|
||||
is_center = abs(hcx - W / 2) < W * center_ratio
|
||||
span = hx1 - hx0
|
||||
|
||||
# V 폴리곤 클러스터링 → 기둥 단위
|
||||
pole_clusters = _cluster_polys(vs, polys, margin=v_cluster_margin)
|
||||
|
||||
# 기둥은 빔 양 끝단(좌 35% / 우 35%)에만 존재. 중앙부 클러스터는 부속물로 제외.
|
||||
x_tol = max(span * 0.5, 50)
|
||||
left_zone = hx0 + span * 0.35 # 좌끝단 경계
|
||||
right_zone = hx1 - span * 0.35 # 우끝단 경계
|
||||
valid_cxs = []
|
||||
for cluster in pole_clusters:
|
||||
ccx = float(np.mean([polys[i][:, 0].mean() for i in cluster]))
|
||||
in_range = hx0 - x_tol <= ccx <= hx1 + x_tol
|
||||
in_end_zone = ccx <= left_zone or ccx >= right_zone
|
||||
if in_range and in_end_zone:
|
||||
valid_cxs.append(ccx)
|
||||
|
||||
n_poles = len(valid_cxs)
|
||||
|
||||
if n_poles >= 2:
|
||||
lcx, rcx = min(valid_cxs), max(valid_cxs)
|
||||
if lcx <= hx0 + span * 0.4 and rcx >= hx1 - span * 0.4:
|
||||
return 'RAAMEN', n_poles
|
||||
return 'PARTIAL', n_poles
|
||||
|
||||
if n_poles == 1:
|
||||
return ('RAAMEN_OCCLUDED', 1) if not is_center else ('', 1)
|
||||
|
||||
return '', 0
|
||||
|
||||
|
||||
def _merge_poly_hull(indices, polys):
|
||||
"""여러 폴리곤 점들을 합쳐 Convex Hull 좌표 반환 (AnyLabeling points 형식)."""
|
||||
all_pts = np.vstack([polys[i] for i in indices])
|
||||
hull = cv2.convexHull(all_pts.astype(np.int32))
|
||||
return [[float(p[0][0]), float(p[0][1])] for p in hull]
|
||||
|
||||
|
||||
def group_detail(group, polys, W, center_ratio=0.2, v_cluster_margin=60):
|
||||
"""
|
||||
라멘 그룹의 세부 구성 분석.
|
||||
Returns dict: main_h, junk_h, valid_pole_clusters, attach_clusters
|
||||
"""
|
||||
hs, vs = group['H'], group['V']
|
||||
if not hs:
|
||||
return {'main_h': None, 'junk_h': [], 'valid_pole_clusters': [], 'attach_clusters': []}
|
||||
|
||||
h_by_area = sorted(hs, key=lambda i: cv2.contourArea(polys[i].astype(np.int32)), reverse=True)
|
||||
main_hs = [h_by_area[0]]
|
||||
if len(h_by_area) > 1:
|
||||
a0 = cv2.contourArea(polys[h_by_area[0]].astype(np.int32))
|
||||
a1 = cv2.contourArea(polys[h_by_area[1]].astype(np.int32))
|
||||
if a1 >= a0 * 0.5:
|
||||
main_hs.append(h_by_area[1])
|
||||
main_h = main_hs[0] # JSON 출력용 대표 빔
|
||||
junk_h = [i for i in hs if i not in main_hs]
|
||||
|
||||
hx0 = int(min(polys[i][:, 0].min() for i in main_hs))
|
||||
hx1 = int(max(polys[i][:, 0].max() for i in main_hs))
|
||||
span = hx1 - hx0
|
||||
|
||||
pole_clusters = _cluster_polys(vs, polys, margin=v_cluster_margin)
|
||||
x_tol = max(span * 0.5, 50)
|
||||
left_zone = hx0 + span * 0.35
|
||||
right_zone = hx1 - span * 0.35
|
||||
|
||||
valid_pole_clusters, attach_clusters = [], []
|
||||
for cluster in pole_clusters:
|
||||
ccx = float(np.mean([polys[i][:, 0].mean() for i in cluster]))
|
||||
in_range = hx0 - x_tol <= ccx <= hx1 + x_tol
|
||||
in_end_zone = ccx <= left_zone or ccx >= right_zone
|
||||
if in_range and in_end_zone:
|
||||
valid_pole_clusters.append(cluster)
|
||||
elif in_range:
|
||||
attach_clusters.append(cluster)
|
||||
|
||||
return {
|
||||
'main_h': main_h,
|
||||
'main_hs': main_hs,
|
||||
'junk_h': junk_h,
|
||||
'valid_pole_clusters': valid_pole_clusters,
|
||||
'attach_clusters': attach_clusters,
|
||||
}
|
||||
|
||||
|
||||
# ── 시각화 상수 ─────────────────────────────────────────────────────────
|
||||
|
||||
_VH_COLOR = {
|
||||
'V': (255, 80, 0), # 주황 (수직 기둥)
|
||||
'H': ( 0, 80, 255), # 파란 (수평 빔)
|
||||
'?': (140, 140, 140), # 회색 (미분류)
|
||||
}
|
||||
_RAAMEN_COLOR = {
|
||||
'RAAMEN': ( 0, 255, 0), # 초록
|
||||
'RAAMEN_CENTER': ( 0, 255, 255), # 노랑 (중앙 영역, H/V 분류 불신뢰)
|
||||
'RAAMEN_OCCLUDED': ( 0, 165, 255), # 주황 (가림/부분 검출)
|
||||
'PARTIAL': (128, 128, 128), # 회색
|
||||
}
|
||||
|
||||
|
||||
# ── 메인 렌더링 ─────────────────────────────────────────────────────────
|
||||
|
||||
def render(image_path, label_path, output_path, args,
|
||||
class_ids=None, class_names=None, epsilon=4.0, v_thresh=20.0,
|
||||
h_max_diff=75.0, vp_min_ar=2.5, vp_min_len=80.0, vp_outer_ratio=0.2,
|
||||
x_horiz_thresh=10.0):
|
||||
|
||||
buf = np.fromfile(str(image_path), dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
H, W = img.shape[:2]
|
||||
|
||||
# ── Phase 1 ──────────────────────────────────────────────────────────
|
||||
raw_polys, poly_abs_idx = load_polygons(label_path, W, H, class_ids, class_names)
|
||||
polys = [smooth_polygon(p, epsilon) for p in raw_polys]
|
||||
print(f" {len(polys)}개 폴리곤 파싱 (epsilon={epsilon})")
|
||||
|
||||
img_cx, img_cy = W / 2.0, H / 2.0
|
||||
|
||||
# ── Phase 1+2: 두 가지 VP 시드 방식으로 시도 → V 폴리곤이 더 많은 VP 채택 ──
|
||||
elong_idx = []
|
||||
for i, pts in enumerate(polys):
|
||||
la, cx, cy, long_side, short_side = _long_axis_angle(pts)
|
||||
if short_side > 0 and long_side / short_side > vp_min_ar and long_side > vp_min_len:
|
||||
elong_idx.append(i)
|
||||
|
||||
# 시드 A: 지배적 장축 각도 클러스터 (이미지 방향 비의존)
|
||||
seeds_A = []
|
||||
if len(elong_idx) >= 2:
|
||||
avals_all = [(i, _long_axis_angle(polys[i])[0] % 180.0) for i in elong_idx]
|
||||
hist, edges = np.histogram([a for _, a in avals_all], bins=12, range=(0.0, 180.0))
|
||||
pk = int(np.argmax(hist))
|
||||
peak_c = (edges[pk] + edges[pk + 1]) / 2.0
|
||||
seeds_A = [i for i, a in avals_all
|
||||
if min(abs(a - peak_c), 180.0 - abs(a - peak_c)) < 20.0]
|
||||
|
||||
# 시드 B: radial cos_sim (이미지 중심 방사 방향 정렬)
|
||||
seeds_B = []
|
||||
for i in elong_idx:
|
||||
_, cx, cy, _, _ = _long_axis_angle(polys[i])
|
||||
dist = ((cx - img_cx) ** 2 + (cy - img_cy) ** 2) ** 0.5
|
||||
if dist > min(W, H) * vp_outer_ratio and _radial_cos_sim(polys[i], img_cx, img_cy) > 0.5:
|
||||
seeds_B.append(i)
|
||||
|
||||
# 두 시드 모두 시도 → V가 더 많이 나오는 VP 채택
|
||||
best_vp_x, best_vp_y = img_cx, -H * 3.0
|
||||
best_n_v = 0
|
||||
orients, adiffs = ['?'] * len(polys), [90.0] * len(polys)
|
||||
for label, seeds in [('지배각', seeds_A), ('radial', seeds_B)]:
|
||||
if len(seeds) < 2:
|
||||
continue
|
||||
vx, vy, nv, ors, ads = _estimate_vp_iterative(
|
||||
polys, seeds, v_thresh, h_max_diff, vp_min_len, x_horiz_thresh)
|
||||
print(f" VP [{label}]: ({vx:.0f}, {vy:.0f}) V={nv}")
|
||||
if nv > best_n_v:
|
||||
best_vp_x, best_vp_y, best_n_v = vx, vy, nv
|
||||
orients, adiffs = ors, ads
|
||||
vp_x, vp_y = best_vp_x, best_vp_y
|
||||
print(f" → 채택 VP: ({vp_x:.1f}, {vp_y:.1f})")
|
||||
|
||||
print(f"\n [V/H 분류] threshold=±{v_thresh}° h_max=±{h_max_diff}°")
|
||||
for i in range(len(polys)):
|
||||
print(f" poly {i:>2d}: {orients[i]} diff={adiffs[i]:.1f}°")
|
||||
print(f" V:{orients.count('V')} H:{orients.count('H')} ?:{orients.count('?')}")
|
||||
|
||||
# ── Phase 3 ──────────────────────────────────────────────────────────
|
||||
groups = connectivity_groups(polys, orients, margin=args.margin)
|
||||
print(f"\n [그룹핑] 연결 컴포넌트 {len(groups)}개 (margin={args.margin}px)")
|
||||
for g in groups:
|
||||
print(f" G{g['id']}: H={g['H']} V={g['V']}")
|
||||
|
||||
# ── Phase 4 ──────────────────────────────────────────────────────────
|
||||
print(f"\n [라멘 판정]")
|
||||
for g in groups:
|
||||
verdict, n_poles = judge_raamen(g, polys, W)
|
||||
g['verdict'] = verdict
|
||||
g['n_poles'] = n_poles
|
||||
pole_str = f"{n_poles}poles" if n_poles else "-"
|
||||
print(f" G{g['id']}: H={g['H']} V={g['V']} → {verdict or '-':18s} ({pole_str})")
|
||||
|
||||
valid_items = [g for g in groups if g['verdict']]
|
||||
valid_items.sort(key=lambda x: x['area'], reverse=True)
|
||||
|
||||
print(f"\n [최종 라멘 객체] {len(valid_items)}개 (면적순)")
|
||||
for g in valid_items:
|
||||
print(f" G{g['id']}: H={g['H']} V={g['V']} → {g['verdict']:18s} ({g['n_poles']}poles) Area={g['area']:,.0f}")
|
||||
|
||||
# 최소 면적 필터링
|
||||
if args.min_group_area > 0:
|
||||
before = len(valid_items)
|
||||
valid_items = [g for g in valid_items if g['area'] >= args.min_group_area]
|
||||
print(f" [필터] 최소 면적 {args.min_group_area} 미만 제거: {before} → {len(valid_items)}개")
|
||||
|
||||
# 모든 그룹: 그룹 내 최하단 꼭짓점 포함 폴리곤이 H이면 → V로 재분류
|
||||
for g in valid_items:
|
||||
all_idxs = g['H'] + g['V']
|
||||
bottom_vi = max(all_idxs, key=lambda i: polys[i][:, 1].max())
|
||||
if bottom_vi in g['H']:
|
||||
g['H'].remove(bottom_vi)
|
||||
g['V'].append(bottom_vi)
|
||||
orients[bottom_vi] = 'V'
|
||||
# RAAMEN_CENTER (H 없는 그룹): 최하단=V, 나머지=H로 display 재조정
|
||||
if not g['H']:
|
||||
for vi in g['V']:
|
||||
orients[vi] = 'V' if vi == bottom_vi else 'H'
|
||||
|
||||
# ── 시각화 ───────────────────────────────────────────────────────────
|
||||
|
||||
# 1. 폴리곤 V/H 색상 반투명 오버레이
|
||||
for i, (pts, orient) in enumerate(zip(polys, orients)):
|
||||
color = _VH_COLOR[orient]
|
||||
pts_i = pts.astype(np.int32)
|
||||
ov = img.copy()
|
||||
cv2.fillPoly(ov, [pts_i], color)
|
||||
cv2.addWeighted(ov, 0.25, img, 0.75, 0, img)
|
||||
cv2.polylines(img, [pts_i], True, color, 2)
|
||||
cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
|
||||
lbl = f"{i}{orient}"
|
||||
cv2.putText(img, lbl, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4)
|
||||
cv2.putText(img, lbl, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
|
||||
|
||||
# 2. VP 시드 폴리곤에 청록 테두리 (채택된 시드셋 재구성)
|
||||
for i in seeds_A + [i for i in seeds_B if i not in seeds_A]:
|
||||
cv2.polylines(img, [polys[i].astype(np.int32)], True, (0, 220, 220), 3)
|
||||
|
||||
# 3. 소실점 표시 (이미지 내부: 원, 외부: 방향 화살표)
|
||||
vp_ix, vp_iy = int(vp_x), int(vp_y)
|
||||
if 0 <= vp_ix < W and 0 <= vp_iy < H:
|
||||
cv2.circle(img, (vp_ix, vp_iy), 20, (0, 255, 255), 3)
|
||||
cv2.putText(img, "VP", (vp_ix + 25, vp_iy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
|
||||
else:
|
||||
ax_s, ay_s = W // 2, H // 5
|
||||
dx, dy = vp_x - img_cx, vp_y - img_cy
|
||||
n = (dx ** 2 + dy ** 2) ** 0.5
|
||||
if n > 0:
|
||||
ax_e = max(0, min(W - 1, int(ax_s + dx / n * 80)))
|
||||
ay_e = max(0, min(H - 1, int(ay_s + dy / n * 80)))
|
||||
cv2.arrowedLine(img, (ax_s, ay_s), (ax_e, ay_e), (0, 255, 255), 3)
|
||||
cv2.putText(img, f"VP({vp_ix},{vp_iy})", (ax_s + 5, ay_s - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
|
||||
|
||||
# 4. 라멘 그룹 강조: 바운딩 박스 + 그룹 ID + 판정 라벨
|
||||
for g in valid_items:
|
||||
verdict = g['verdict']
|
||||
color = _RAAMEN_COLOR[verdict]
|
||||
all_idx = g['H'] + g['V']
|
||||
all_pts = np.vstack([polys[i] for i in all_idx]).astype(np.int32)
|
||||
x0 = all_pts[:, 0].min() - 15; y0 = all_pts[:, 1].min() - 15
|
||||
x1 = all_pts[:, 0].max() + 15; y1 = all_pts[:, 1].max() + 15
|
||||
cv2.rectangle(img, (x0, y0), (x1, y1), color, 4)
|
||||
lbl = f"G{g['id']} {verdict} ({g['n_poles']}poles)"
|
||||
cv2.putText(img, lbl, (x0, y0 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 0), 6)
|
||||
cv2.putText(img, lbl, (x0, y0 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2)
|
||||
|
||||
# 이미지 저장
|
||||
scale = min(1.0, 4096 / max(H, W))
|
||||
if scale < 1.0:
|
||||
img = cv2.resize(img, (int(W * scale), int(H * scale)))
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imencode(output_path.suffix, img)[1].tofile(str(output_path))
|
||||
print(f"\n → {output_path}")
|
||||
|
||||
# AnyLabeling JSON 저장
|
||||
import json
|
||||
|
||||
def _shape(label, points, gid, desc):
|
||||
return {"label": label, "score": None, "points": points,
|
||||
"group_id": gid, "description": desc,
|
||||
"shape_type": "polygon", "flags": None}
|
||||
|
||||
shapes = []
|
||||
for g in valid_items:
|
||||
gid = g['id']
|
||||
verdict = g['verdict']
|
||||
desc_base = f"G{gid} {verdict}"
|
||||
|
||||
def abs_ids(rel_list):
|
||||
"""상대 폴리곤 인덱스 → 절대 JSON shapes[] 인덱스 변환."""
|
||||
return sorted(poly_abs_idx[i] for i in rel_list)
|
||||
|
||||
if not g['H']:
|
||||
# RAAMEN_CENTER: 최하단 꼭짓점 포함 폴리곤 = 기둥, 나머지 = 빔
|
||||
bottom_vi = max(g['V'], key=lambda i: polys[i][:, 1].max())
|
||||
for vi in g['V']:
|
||||
label = "raamen_pole" if vi == bottom_vi else "raamen_beam"
|
||||
shapes.append(_shape(label,
|
||||
[[float(p[0]), float(p[1])] for p in polys[vi]],
|
||||
gid, f"{desc_base} shape#{poly_abs_idx[vi]}"))
|
||||
else:
|
||||
# 일반 RAAMEN: V = 기둥, H = 빔
|
||||
for vi in g['V']:
|
||||
shapes.append(_shape("raamen_pole",
|
||||
[[float(p[0]), float(p[1])] for p in polys[vi]],
|
||||
gid, f"{desc_base} pole shape#{poly_abs_idx[vi]}"))
|
||||
for hi in g['H']:
|
||||
shapes.append(_shape("raamen_beam",
|
||||
[[float(p[0]), float(p[1])] for p in polys[hi]],
|
||||
gid, f"{desc_base} beam shape#{poly_abs_idx[hi]}"))
|
||||
|
||||
anylabel_json = {
|
||||
"version": "3.3.9",
|
||||
"flags": {},
|
||||
"shapes": shapes,
|
||||
"imagePath": image_path.name,
|
||||
"imageData": None,
|
||||
"imageHeight": H,
|
||||
"imageWidth": W,
|
||||
}
|
||||
json_path = output_path.with_suffix(".json")
|
||||
json_path.write_text(json.dumps(anylabel_json, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8")
|
||||
print(f" → {json_path}")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--image", required=True)
|
||||
ap.add_argument("--label", required=True)
|
||||
ap.add_argument("--output", required=True)
|
||||
ap.add_argument("--class-ids", default="",
|
||||
help="포함할 클래스 ID, 콤마 구분 (.txt 전용)")
|
||||
ap.add_argument("--class-names", default="catenary_pole",
|
||||
help="포함할 클래스 이름, 콤마 구분 (.json 전용, 기본: 'catenary_pole')")
|
||||
ap.add_argument("--epsilon", type=float, default=4.0,
|
||||
help="approxPolyDP epsilon (기본 4.0px)")
|
||||
ap.add_argument("--v-thresh", type=float, default=20.0,
|
||||
help="V/H 분류 각도 임계값 degrees (기본 20°)")
|
||||
ap.add_argument("--h-max-diff", type=float, default=75.0,
|
||||
help="H(빔) 최대 각도 diff; 이 이상은 레일/전선 등으로 제외 (기본 75°)")
|
||||
ap.add_argument("--x-horiz-thresh", type=float, default=10.0,
|
||||
help="X축 절대 수평 제외 임계값 degrees; AR≥4 AND 이 각도 이내 → 제외 (기본 10°)")
|
||||
ap.add_argument("--margin", type=int, default=30,
|
||||
help="폴리곤 접촉 판정 margin px (기본 30)")
|
||||
ap.add_argument("--min-group-area", type=float, default=0,
|
||||
help="라멘 그룹의 최소 면적 합계 (기본 0)")
|
||||
args = ap.parse_args()
|
||||
|
||||
class_ids = ({int(x) for x in args.class_ids.split(',') if x.strip()}
|
||||
if args.class_ids else None)
|
||||
class_names = ({x.strip() for x in args.class_names.split(',') if x.strip()}
|
||||
if args.class_names else None)
|
||||
|
||||
out = Path(args.output)
|
||||
folder = out.parent / out.stem # e.g. output/0004_test
|
||||
if folder.exists():
|
||||
n = 1
|
||||
while (out.parent / f"{out.stem}_{n}").exists():
|
||||
n += 1
|
||||
folder = out.parent / f"{out.stem}_{n}"
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
out = folder / out.name # e.g. output/0004_test/0004_test.jpg
|
||||
print(f" [출력 폴더] {folder}")
|
||||
|
||||
render(Path(args.image), Path(args.label), out, args,
|
||||
class_ids=class_ids, class_names=class_names,
|
||||
epsilon=args.epsilon, v_thresh=args.v_thresh,
|
||||
h_max_diff=args.h_max_diff, x_horiz_thresh=args.x_horiz_thresh)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
503
tools/labeling_server.py
Normal file
503
tools/labeling_server.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""
|
||||
Control Box 라벨링 서버 v2 — 전체 이미지 + bbox 오버레이 방식
|
||||
사용법: python tools/labeling_server.py \
|
||||
--json "data/역사이미지/slope/DJI_20260306113838_0004_everything.json"
|
||||
[--reset] DB 초기화
|
||||
브라우저: http://localhost:7001
|
||||
"""
|
||||
import argparse, json, sqlite3, sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import cv2, numpy as np
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, Response
|
||||
import uvicorn
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
DB = ROOT / "labels" / "labeling.db"
|
||||
MIN_VOTES = 3
|
||||
TRUE_RATIO = 0.6
|
||||
MAX_DIM = 3000 # served image max dimension (px)
|
||||
|
||||
CONTROL_LABELS = {
|
||||
"small dark object on ballast", "manhole cover", "trackside enclosure",
|
||||
"dark rectangle on ground", "small square box", "small metal cabinet",
|
||||
"small cube on gravel", "compact trackside junction box",
|
||||
"small square gray metal box beside rail",
|
||||
"small near-square electrical enclosure on the ground",
|
||||
}
|
||||
|
||||
app = FastAPI()
|
||||
_img_data: bytes = None
|
||||
_orig_w = _orig_h = _disp_w = _disp_h = 0
|
||||
|
||||
|
||||
# ── DB ────────────────────────────────────────────────────────────────────────
|
||||
def get_db():
|
||||
con = sqlite3.connect(str(DB))
|
||||
con.row_factory = sqlite3.Row
|
||||
return con
|
||||
|
||||
|
||||
def init_db(candidates: list):
|
||||
DB.parent.mkdir(parents=True, exist_ok=True)
|
||||
con = get_db()
|
||||
con.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS candidates (
|
||||
id INTEGER PRIMARY KEY,
|
||||
json_idx INTEGER,
|
||||
label TEXT,
|
||||
score REAL,
|
||||
bbox TEXT,
|
||||
image_path TEXT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS votes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
candidate_id INTEGER,
|
||||
user TEXT,
|
||||
vote INTEGER,
|
||||
ts DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(candidate_id, user)
|
||||
);
|
||||
""")
|
||||
if con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0] > 0:
|
||||
n = con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0]
|
||||
print(f"DB 기존 데이터 유지. candidates={n}")
|
||||
con.close()
|
||||
return
|
||||
for c in candidates:
|
||||
con.execute(
|
||||
"INSERT INTO candidates(json_idx,label,score,bbox,image_path) VALUES(?,?,?,?,?)",
|
||||
(c["json_idx"], c["label"], c["score"], json.dumps(c["bbox"]), c["image_path"]),
|
||||
)
|
||||
con.commit()
|
||||
print(f"DB 등록: {len(candidates)}개")
|
||||
con.close()
|
||||
|
||||
|
||||
# ── 후보 로드 ─────────────────────────────────────────────────────────────────
|
||||
def load_candidates(json_path: Path) -> list:
|
||||
data = json.loads(json_path.read_text(encoding="utf-8"))
|
||||
stem = json_path.stem.replace("_everything", "")
|
||||
img_path = next(
|
||||
(json_path.parent / f"{stem}{ext}" for ext in (".JPG", ".jpg", ".png")
|
||||
if (json_path.parent / f"{stem}{ext}").exists()),
|
||||
None,
|
||||
)
|
||||
if img_path is None:
|
||||
raise FileNotFoundError(f"원본 이미지 없음: {json_path.parent / stem}.*")
|
||||
|
||||
candidates = []
|
||||
for idx, seg in enumerate(data.get("segments", [])):
|
||||
label = seg.get("label", "").strip()
|
||||
if label not in CONTROL_LABELS:
|
||||
continue
|
||||
x0, y0, x1, y1 = [float(v) for v in seg["bbox"]]
|
||||
if (x1 - x0) < 4 or (y1 - y0) < 4:
|
||||
continue
|
||||
candidates.append({
|
||||
"json_idx": idx,
|
||||
"label": label,
|
||||
"score": float(seg.get("score", 0)),
|
||||
"bbox": [x0, y0, x1, y1],
|
||||
"image_path": str(img_path),
|
||||
})
|
||||
|
||||
segs_total = len(data.get("segments", []))
|
||||
print(f"필터링: {segs_total}개 → {len(candidates)}개 control_box 후보")
|
||||
return candidates
|
||||
|
||||
|
||||
# ── 이미지 준비 (서버 시작 시 1회) ───────────────────────────────────────────
|
||||
def prepare_image(img_path: Path):
|
||||
global _img_data, _orig_w, _orig_h, _disp_w, _disp_h
|
||||
buf = np.fromfile(str(img_path), dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise ValueError(f"이미지 로드 실패: {img_path}")
|
||||
_orig_h, _orig_w = img.shape[:2]
|
||||
scale = min(1.0, MAX_DIM / max(_orig_w, _orig_h))
|
||||
_disp_w, _disp_h = int(_orig_w * scale), int(_orig_h * scale)
|
||||
if scale < 1.0:
|
||||
img = cv2.resize(img, (_disp_w, _disp_h), interpolation=cv2.INTER_LANCZOS4)
|
||||
_, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
||||
_img_data = bytes(enc)
|
||||
print(f"이미지 준비: {_orig_w}×{_orig_h} → {_disp_w}×{_disp_h}")
|
||||
|
||||
|
||||
# ── API ───────────────────────────────────────────────────────────────────────
|
||||
@app.get("/image")
|
||||
async def serve_image():
|
||||
return Response(content=_img_data, media_type="image/jpeg")
|
||||
|
||||
|
||||
@app.get("/api/image_info")
|
||||
async def api_image_info():
|
||||
return {"orig_w": _orig_w, "orig_h": _orig_h, "disp_w": _disp_w, "disp_h": _disp_h}
|
||||
|
||||
|
||||
@app.get("/api/candidates")
|
||||
async def api_candidates(user: str = ""):
|
||||
con = get_db()
|
||||
rows = con.execute("""
|
||||
SELECT c.id, c.bbox, c.label, c.score, v.vote
|
||||
FROM candidates c
|
||||
LEFT JOIN votes v ON c.id=v.candidate_id AND v.user=?
|
||||
""", (user,)).fetchall()
|
||||
con.close()
|
||||
return {"candidates": [
|
||||
{
|
||||
"id": r["id"],
|
||||
"bbox": json.loads(r["bbox"]),
|
||||
"label": r["label"],
|
||||
"score": round(r["score"], 2),
|
||||
"voted": r["vote"] is not None,
|
||||
"is_true": r["vote"] == 1 if r["vote"] is not None else None,
|
||||
} for r in rows
|
||||
]}
|
||||
|
||||
|
||||
@app.post("/api/vote")
|
||||
async def api_vote(data: dict):
|
||||
user = data.get("user", "").strip()
|
||||
all_ids = data.get("all_ids", [])
|
||||
true_ids = set(data.get("true_ids", []))
|
||||
if not user or not all_ids:
|
||||
return {"ok": False, "error": "파라미터 오류"}
|
||||
con = get_db()
|
||||
for cid in all_ids:
|
||||
con.execute(
|
||||
"INSERT OR REPLACE INTO votes(candidate_id,user,vote) VALUES(?,?,?)",
|
||||
(cid, user, 1 if cid in true_ids else 0),
|
||||
)
|
||||
con.commit()
|
||||
con.close()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/api/stats")
|
||||
async def api_stats():
|
||||
con = get_db()
|
||||
total = con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0]
|
||||
voted3 = con.execute(f"""
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT candidate_id FROM votes GROUP BY candidate_id HAVING COUNT(*) >= {MIN_VOTES}
|
||||
)
|
||||
""").fetchone()[0]
|
||||
users = con.execute("SELECT COUNT(DISTINCT user) FROM votes").fetchone()[0]
|
||||
tvotes = con.execute("SELECT COUNT(*) FROM votes").fetchone()[0]
|
||||
con.close()
|
||||
return {"total": total, "voted3plus": voted3, "users": users, "total_votes": tvotes}
|
||||
|
||||
|
||||
@app.post("/api/export")
|
||||
async def api_export():
|
||||
con = get_db()
|
||||
confirmed = con.execute(f"""
|
||||
SELECT c.id, c.bbox, c.image_path,
|
||||
SUM(v.vote) AS true_v, COUNT(v.id) AS total_v
|
||||
FROM candidates c JOIN votes v ON c.id=v.candidate_id
|
||||
GROUP BY c.id
|
||||
HAVING total_v >= {MIN_VOTES} AND (CAST(true_v AS REAL)/total_v) >= {TRUE_RATIO}
|
||||
""").fetchall()
|
||||
con.close()
|
||||
|
||||
by_image = defaultdict(list)
|
||||
for row in confirmed:
|
||||
by_image[row["image_path"]].append(row)
|
||||
|
||||
out_dir = ROOT / "labels" / "yolo_export"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
count = 0
|
||||
for img_path, rows in by_image.items():
|
||||
buf = np.fromfile(img_path, dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
H, W = img.shape[:2]
|
||||
txt = out_dir / (Path(img_path).stem + ".txt")
|
||||
with open(txt, "w") as f:
|
||||
for row in rows:
|
||||
x0, y0, x1, y1 = json.loads(row["bbox"])
|
||||
f.write(f"0 {((x0+x1)/2)/W:.6f} {((y0+y1)/2)/H:.6f} "
|
||||
f"{(x1-x0)/W:.6f} {(y1-y0)/H:.6f}\n")
|
||||
count += 1
|
||||
return {"ok": True, "labels": count, "dir": str(out_dir)}
|
||||
|
||||
|
||||
# ── HTML ──────────────────────────────────────────────────────────────────────
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index():
|
||||
return HTML
|
||||
|
||||
|
||||
HTML = r"""<!DOCTYPE html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Control Box 라벨링</title>
|
||||
<style>
|
||||
*{box-sizing:border-box;margin:0;padding:0}
|
||||
body{font-family:'Segoe UI',sans-serif;background:#1a1a2e;color:#e0e0e0;display:flex;flex-direction:column;height:100vh;overflow:hidden}
|
||||
#toolbar{background:#16213e;border-bottom:1px solid #0f3460;padding:8px 14px;display:flex;align-items:center;gap:10px;flex-shrink:0;flex-wrap:wrap}
|
||||
h1{color:#e94560;font-size:1rem;white-space:nowrap}
|
||||
input[type=text]{background:#0f3460;border:1px solid #2a4a7a;color:#e0e0e0;padding:5px 10px;border-radius:4px;font-size:.85rem;width:130px}
|
||||
input[type=text]:focus{outline:none;border-color:#e94560}
|
||||
.btn{background:#e94560;color:#fff;border:none;padding:6px 12px;border-radius:5px;cursor:pointer;font-size:.82rem;font-weight:700;white-space:nowrap}
|
||||
.btn:hover{background:#c73652}
|
||||
.btn-sec{background:#0f3460;border:1px solid #2a4a7a;color:#e0e0e0}
|
||||
.btn-sec:hover{background:#1a5299}
|
||||
#stats{font-size:.75rem;color:#777;white-space:nowrap}
|
||||
#legend{display:flex;gap:10px;align-items:center;font-size:.72rem;color:#888}
|
||||
.leg{display:inline-flex;align-items:center;gap:3px}
|
||||
.lb{width:12px;height:12px;border:2px solid;border-radius:2px;flex-shrink:0}
|
||||
#wrap{flex:1;position:relative;overflow:hidden;cursor:crosshair}
|
||||
canvas{display:block}
|
||||
#tip{position:absolute;background:rgba(0,0,0,.85);color:#fff;font-size:.7rem;padding:3px 8px;border-radius:3px;pointer-events:none;display:none;max-width:300px}
|
||||
#hint{position:absolute;bottom:8px;left:50%;transform:translateX(-50%);font-size:.72rem;color:#444;pointer-events:none;white-space:nowrap}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="toolbar">
|
||||
<h1>🎯 Control Box</h1>
|
||||
<input type="text" id="uname" placeholder="이름/사번" autocomplete="off">
|
||||
<button class="btn" onclick="init()">시작</button>
|
||||
<button class="btn btn-sec" onclick="fitView()">맞춤(F)</button>
|
||||
<span id="stats">—</span>
|
||||
<div id="legend">
|
||||
<span class="leg"><span class="lb" style="border-color:#ffaa00"></span>미투표</span>
|
||||
<span class="leg"><span class="lb" style="border-color:#00cc66"></span>컨트롤박스 ✓</span>
|
||||
<span class="leg"><span class="lb" style="border-color:#cc4444"></span>아님 ✗</span>
|
||||
<span class="leg"><span class="lb" style="border-color:#444"></span>타인투표완료</span>
|
||||
</div>
|
||||
<div style="flex:1"></div>
|
||||
<button class="btn btn-sec" onclick="doExport()" style="font-size:.75rem">YOLO 내보내기</button>
|
||||
</div>
|
||||
<div id="wrap">
|
||||
<canvas id="cv"></canvas>
|
||||
<div id="tip"></div>
|
||||
<div id="hint">휠=줌 | 드래그=이동 | 클릭=YES/NO 토글 | F=맞춤</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const cv = document.getElementById('cv');
|
||||
const ctx = cv.getContext('2d');
|
||||
const wrap = document.getElementById('wrap');
|
||||
|
||||
let user = '', cands = [], imgScale = 1;
|
||||
let origW = 0, origH = 0, dispW = 0, dispH = 0;
|
||||
let sc = 1, ox = 0, oy = 0;
|
||||
let isDrag = false, dragX = 0, dragY = 0, moved = false;
|
||||
const img = new Image();
|
||||
|
||||
function resizeCv() {
|
||||
cv.width = wrap.clientWidth;
|
||||
cv.height = wrap.clientHeight;
|
||||
render();
|
||||
}
|
||||
window.addEventListener('resize', resizeCv);
|
||||
resizeCv();
|
||||
|
||||
async function init() {
|
||||
user = document.getElementById('uname').value.trim();
|
||||
if (!user) { alert('이름 입력 필요'); return; }
|
||||
const info = await fetch('/api/image_info').then(r => r.json());
|
||||
origW = info.orig_w; origH = info.orig_h;
|
||||
dispW = info.disp_w; dispH = info.disp_h;
|
||||
imgScale = dispW / origW;
|
||||
img.onload = () => { fitView(); loadCands(); };
|
||||
img.src = '/image?' + Date.now();
|
||||
}
|
||||
|
||||
async function loadCands() {
|
||||
const d = await fetch('/api/candidates?user=' + encodeURIComponent(user)).then(r => r.json());
|
||||
cands = d.candidates.map(c => ({...c, _sel: c.voted ? c.is_true : undefined}));
|
||||
updateStats();
|
||||
render();
|
||||
}
|
||||
|
||||
function updateStats() {
|
||||
const myVotes = cands.filter(c => c._sel !== undefined).length;
|
||||
const yes = cands.filter(c => c._sel === true).length;
|
||||
const othersVoted = cands.filter(c => c.voted && c._sel === undefined).length;
|
||||
document.getElementById('stats').textContent =
|
||||
`후보 ${cands.length}개 | 내투표 ${myVotes}개 (YES:${yes}) | 타인완료 ${othersVoted}개`;
|
||||
}
|
||||
|
||||
function fitView() {
|
||||
if (!img.naturalWidth) return;
|
||||
const ww = wrap.clientWidth, wh = wrap.clientHeight;
|
||||
sc = Math.min(ww / dispW, wh / dispH) * 0.97;
|
||||
ox = (ww - dispW * sc) / 2;
|
||||
oy = (wh - dispH * sc) / 2;
|
||||
render();
|
||||
}
|
||||
|
||||
function render() {
|
||||
const W = cv.width, H = cv.height;
|
||||
ctx.fillStyle = '#0d0d1e';
|
||||
ctx.fillRect(0, 0, W, H);
|
||||
if (!img.naturalWidth) return;
|
||||
|
||||
ctx.save();
|
||||
ctx.translate(ox, oy);
|
||||
ctx.scale(sc, sc);
|
||||
ctx.drawImage(img, 0, 0, dispW, dispH);
|
||||
|
||||
const lw = Math.max(1, 2 / sc);
|
||||
for (const c of cands) {
|
||||
const [x0, y0, x1, y1] = c.bbox.map(v => v * imgScale);
|
||||
const w = x1 - x0, h = y1 - y0;
|
||||
let color, fill = null;
|
||||
if (c._sel === true) { color = '#00cc66'; fill = 'rgba(0,204,102,0.15)'; }
|
||||
else if (c._sel === false) { color = '#cc4444'; fill = 'rgba(204,68,68,0.15)'; }
|
||||
else if (c.voted) { color = '#444444'; }
|
||||
else { color = c.score > 0.5 ? '#ffaa00' : '#ff7700aa'; }
|
||||
|
||||
ctx.lineWidth = lw;
|
||||
ctx.strokeStyle = color;
|
||||
if (fill) {
|
||||
ctx.fillStyle = fill;
|
||||
ctx.fillRect(x0, y0, w, h);
|
||||
}
|
||||
ctx.strokeRect(x0, y0, w, h);
|
||||
}
|
||||
ctx.restore();
|
||||
}
|
||||
|
||||
// ── wheel zoom ────────────────────────────────────────────────────────────────
|
||||
wrap.addEventListener('wheel', e => {
|
||||
e.preventDefault();
|
||||
const rect = wrap.getBoundingClientRect();
|
||||
const mx = e.clientX - rect.left, my = e.clientY - rect.top;
|
||||
const f = e.deltaY < 0 ? 1.15 : 1/1.15;
|
||||
ox = mx - (mx - ox) * f;
|
||||
oy = my - (my - oy) * f;
|
||||
sc *= f;
|
||||
render();
|
||||
}, {passive: false});
|
||||
|
||||
// ── drag pan ──────────────────────────────────────────────────────────────────
|
||||
wrap.addEventListener('mousedown', e => {
|
||||
if (e.button !== 0) return;
|
||||
isDrag = true; moved = false;
|
||||
dragX = e.clientX; dragY = e.clientY;
|
||||
wrap.style.cursor = 'grabbing';
|
||||
});
|
||||
window.addEventListener('mousemove', e => {
|
||||
if (isDrag) {
|
||||
const dx = e.clientX - dragX, dy = e.clientY - dragY;
|
||||
if (Math.abs(dx) + Math.abs(dy) > 3) moved = true;
|
||||
if (moved) { ox += dx; oy += dy; dragX = e.clientX; dragY = e.clientY; render(); }
|
||||
}
|
||||
// tooltip
|
||||
const c = hitTest(e);
|
||||
const tip = document.getElementById('tip');
|
||||
if (c) {
|
||||
const rect = wrap.getBoundingClientRect();
|
||||
tip.style.display = 'block';
|
||||
tip.style.left = (e.clientX - rect.left + 14) + 'px';
|
||||
tip.style.top = (e.clientY - rect.top + 4) + 'px';
|
||||
const status = c._sel === true ? '✅ YES' : c._sel === false ? '❌ NO' : c.voted ? '⬜ 타인완료' : '❓ 미투표';
|
||||
tip.textContent = `[${c.id}] ${c.label} (${c.score}) — ${status}`;
|
||||
} else {
|
||||
tip.style.display = 'none';
|
||||
}
|
||||
});
|
||||
window.addEventListener('mouseup', e => {
|
||||
if (!isDrag) return;
|
||||
wrap.style.cursor = 'crosshair';
|
||||
isDrag = false;
|
||||
if (!moved) handleClick(e);
|
||||
});
|
||||
|
||||
// ── hit test ──────────────────────────────────────────────────────────────────
|
||||
function imgCoords(e) {
|
||||
const rect = wrap.getBoundingClientRect();
|
||||
return [(e.clientX - rect.left - ox) / sc / imgScale,
|
||||
(e.clientY - rect.top - oy) / sc / imgScale];
|
||||
}
|
||||
|
||||
function hitTest(e) {
|
||||
const [ix, iy] = imgCoords(e);
|
||||
let best = null, bestArea = Infinity;
|
||||
for (const c of cands) {
|
||||
const [x0, y0, x1, y1] = c.bbox;
|
||||
if (ix >= x0 && ix <= x1 && iy >= y0 && iy <= y1) {
|
||||
const a = (x1-x0)*(y1-y0);
|
||||
if (a < bestArea) { bestArea = a; best = c; }
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
|
||||
async function handleClick(e) {
|
||||
if (!user) return;
|
||||
const c = hitTest(e);
|
||||
if (!c) return;
|
||||
// Toggle: undefined→YES→NO→YES...
|
||||
c._sel = c._sel === true ? false : true;
|
||||
render();
|
||||
updateStats();
|
||||
// Save immediately
|
||||
await fetch('/api/vote', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({user, all_ids: [c.id], true_ids: c._sel ? [c.id] : []})
|
||||
});
|
||||
c.voted = true;
|
||||
c.is_true = c._sel;
|
||||
}
|
||||
|
||||
// ── keyboard ──────────────────────────────────────────────────────────────────
|
||||
window.addEventListener('keydown', e => {
|
||||
if (e.key === 'f' || e.key === 'F') fitView();
|
||||
});
|
||||
|
||||
async function doExport() {
|
||||
const r = await fetch('/api/export', {method: 'POST'}).then(r => r.json());
|
||||
alert(`완료: ${r.labels}개 라벨\n경로: ${r.dir}`);
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
# ── 메인 ─────────────────────────────────────────────────────────────────────
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
ap.add_argument("--json", required=True, help="*_everything.json 경로")
|
||||
ap.add_argument("--port", type=int, default=7001)
|
||||
ap.add_argument("--reset", action="store_true", help="DB 초기화 후 재로드")
|
||||
args = ap.parse_args()
|
||||
|
||||
json_path = Path(args.json)
|
||||
if not json_path.exists():
|
||||
print(f"JSON 없음: {json_path}"); sys.exit(1)
|
||||
|
||||
if args.reset and DB.exists():
|
||||
DB.unlink()
|
||||
print("DB 초기화 완료.")
|
||||
|
||||
print("후보 로딩 중...")
|
||||
candidates = load_candidates(json_path)
|
||||
init_db(candidates)
|
||||
|
||||
# Prepare image
|
||||
stem = json_path.stem.replace("_everything", "")
|
||||
img_path = next(
|
||||
(json_path.parent / f"{stem}{ext}" for ext in (".JPG", ".jpg", ".png")
|
||||
if (json_path.parent / f"{stem}{ext}").exists()),
|
||||
None,
|
||||
)
|
||||
if img_path is None:
|
||||
print(f"원본 이미지 없음: {json_path.parent / stem}.*"); sys.exit(1)
|
||||
prepare_image(img_path)
|
||||
|
||||
print(f"\n라벨링 서버 시작: http://localhost:{args.port}")
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="warning")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
90
tools/merge_tiles_vis.py
Normal file
90
tools/merge_tiles_vis.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
6×8 (또는 임의 그리드) 타일 라벨을 원본 이미지 좌표로 병합해 시각화.
|
||||
|
||||
사용:
|
||||
python tools/merge_tiles_vis.py \
|
||||
--orig data/역사이미지/slope/DJI_20260306113838_0004.JPG \
|
||||
--labels output/autolabel/tile6x8/labels \
|
||||
--output output/autolabel/tile6x8/merged_vis.jpg \
|
||||
--cols 6 --rows 8
|
||||
"""
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
CLASS_NAMES = ["catenary_pole", "bracket"]
|
||||
CLASS_COLORS = [(0, 200, 255), (255, 130, 0)]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--orig", required=True)
|
||||
parser.add_argument("--labels", required=True)
|
||||
parser.add_argument("--output", required=True)
|
||||
parser.add_argument("--cols", type=int, default=6)
|
||||
parser.add_argument("--rows", type=int, default=8)
|
||||
parser.add_argument("--alpha", type=float, default=0.3)
|
||||
args = parser.parse_args()
|
||||
|
||||
buf = np.fromfile(args.orig, dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
H, W = img.shape[:2]
|
||||
tw, th = W // args.cols, H // args.rows
|
||||
|
||||
label_dir = Path(args.labels)
|
||||
counts = [0] * len(CLASS_NAMES)
|
||||
|
||||
for r in range(args.rows):
|
||||
for c in range(args.cols):
|
||||
label_file = label_dir / f"tile_r{r+1:02d}_c{c+1:02d}.txt"
|
||||
if not label_file.exists():
|
||||
continue
|
||||
x0 = c * tw
|
||||
y0 = r * th
|
||||
tile_w = tw if c < args.cols - 1 else W - x0
|
||||
tile_h = th if r < args.rows - 1 else H - y0
|
||||
|
||||
text = label_file.read_text(encoding="utf-8").strip()
|
||||
if not text:
|
||||
continue
|
||||
for line in text.splitlines():
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
continue
|
||||
cls_id = int(parts[0])
|
||||
coords = list(map(float, parts[1:]))
|
||||
pts = np.array(
|
||||
[[coords[i] * tile_w + x0, coords[i + 1] * tile_h + y0]
|
||||
for i in range(0, len(coords), 2)],
|
||||
dtype=np.int32,
|
||||
)
|
||||
color = CLASS_COLORS[cls_id % len(CLASS_COLORS)]
|
||||
overlay = img.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, args.alpha, img, 1 - args.alpha, 0, img)
|
||||
cv2.polylines(img, [pts], True, color, 3)
|
||||
if cls_id < len(counts):
|
||||
counts[cls_id] += 1
|
||||
|
||||
# 범례
|
||||
for i, name in enumerate(CLASS_NAMES):
|
||||
y = 30 + i * 50
|
||||
cv2.rectangle(img, (15, y - 20), (55, y + 10), CLASS_COLORS[i], -1)
|
||||
cv2.putText(img, f"{name}: {counts[i]}",
|
||||
(65, y), cv2.FONT_HERSHEY_SIMPLEX, 1.5, CLASS_COLORS[i], 3)
|
||||
|
||||
total = sum(counts)
|
||||
print(f"총 {total}개 " + ", ".join(f"{CLASS_NAMES[i]}={counts[i]}" for i in range(len(counts))))
|
||||
|
||||
# 4096px 이하로 미리보기 저장
|
||||
scale = min(1.0, 4096 / max(H, W))
|
||||
vis = cv2.resize(img, (int(W * scale), int(H * scale)))
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imencode(out_path.suffix, vis)[1].tofile(str(out_path))
|
||||
print(f"저장: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
154
tools/post_merge_poles.py
Normal file
154
tools/post_merge_poles.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""post_merge_poles.py — detect_all_objects.py 출력 JSON에서 catenary_pole 병합.
|
||||
|
||||
detecting은 한 번만, 병합 파라미터 조정 시 이 스크립트만 재실행.
|
||||
|
||||
Usage:
|
||||
python tools/post_merge_poles.py INPUT.json [--x-overlap 0.30] [--y-gap 150]
|
||||
python tools/post_merge_poles.py INPUT.json --inplace
|
||||
python tools/post_merge_poles.py INPUT.json --output OUTPUT.json
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _poly_orient(points: list, H: int, W: int) -> str:
|
||||
"""폴리곤 장축 방향 판별.
|
||||
|
||||
V: 장축이 이미지 중심 radial 방향 정렬(cos_sim > 0.7) → 세로 기둥
|
||||
H: 직교 → 수평 빔
|
||||
?: aspect ratio < 1.3
|
||||
"""
|
||||
pts = np.array(points, dtype=np.float32)
|
||||
rect = cv2.minAreaRect(pts)
|
||||
(rx, ry), (rw, rh), angle = rect
|
||||
if min(rw, rh) < 1:
|
||||
return '?'
|
||||
ar = max(rw, rh) / min(rw, rh)
|
||||
if ar < 1.3:
|
||||
return '?'
|
||||
long_angle_deg = angle if rw >= rh else angle + 90
|
||||
lx = float(np.cos(np.radians(long_angle_deg)))
|
||||
ly = float(np.sin(np.radians(long_angle_deg)))
|
||||
img_cx, img_cy = W / 2.0, H / 2.0
|
||||
rdx, rdy = rx - img_cx, ry - img_cy
|
||||
radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
|
||||
if radial_norm < 1:
|
||||
return '?'
|
||||
rdx, rdy = rdx / radial_norm, rdy / radial_norm
|
||||
cos_sim = abs(lx * rdx + ly * rdy)
|
||||
return 'V' if cos_sim > 0.7 else 'H'
|
||||
|
||||
|
||||
def merge_poles(shapes: list, H: int, W: int,
|
||||
x_overlap_thresh: float = 0.30,
|
||||
y_gap_thresh: int = 150) -> list:
|
||||
"""V+V 조합 전철주만 타일 경계 병합."""
|
||||
if len(shapes) <= 1:
|
||||
return shapes
|
||||
|
||||
orients = [_poly_orient(s["points"], H, W) for s in shapes]
|
||||
v_count = sum(1 for o in orients if o == 'V')
|
||||
h_count = sum(1 for o in orients if o == 'H')
|
||||
print(f" orient: V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}")
|
||||
|
||||
def get_bbox(s):
|
||||
xs = [p[0] for p in s["points"]]; ys = [p[1] for p in s["points"]]
|
||||
return min(xs), min(ys), max(xs), max(ys)
|
||||
|
||||
def x_overlap_ratio(b1, b2):
|
||||
ox = min(b1[2], b2[2]) - max(b1[0], b2[0])
|
||||
ux = max(b1[2], b2[2]) - min(b1[0], b2[0])
|
||||
return ox / ux if ux > 0 else 0.0
|
||||
|
||||
def y_gap(b1, b2):
|
||||
return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))
|
||||
|
||||
def merge_two(s1, s2):
|
||||
mask = np.zeros((H, W), dtype=np.uint8)
|
||||
for s in (s1, s2):
|
||||
cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return s1
|
||||
c = max(contours, key=cv2.contourArea)
|
||||
eps = 0.002 * cv2.arcLength(c, True)
|
||||
approx = cv2.approxPolyDP(c, eps, True)
|
||||
merged = dict(s1)
|
||||
merged["points"] = [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
merged["score"] = max(float(s1.get("score", 0)), float(s2.get("score", 0)))
|
||||
return merged
|
||||
|
||||
merged_flags = [False] * len(shapes)
|
||||
result = []
|
||||
merged_count = 0
|
||||
for i in range(len(shapes)):
|
||||
if merged_flags[i]:
|
||||
continue
|
||||
cur = shapes[i]
|
||||
cur_ori = orients[i]
|
||||
cb = get_bbox(cur)
|
||||
for j in range(i + 1, len(shapes)):
|
||||
if merged_flags[j]:
|
||||
continue
|
||||
if cur_ori != 'V' or orients[j] != 'V':
|
||||
continue
|
||||
jb = get_bbox(shapes[j])
|
||||
if x_overlap_ratio(cb, jb) >= x_overlap_thresh and y_gap(cb, jb) <= y_gap_thresh:
|
||||
cur = merge_two(cur, shapes[j])
|
||||
cur_ori = 'V'
|
||||
cb = get_bbox(cur)
|
||||
merged_flags[j] = True
|
||||
merged_count += 1
|
||||
result.append(cur)
|
||||
print(f" 병합: {len(shapes)} → {len(result)}개 ({merged_count}쌍 합침)")
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("input", help="AnyLabeling JSON 파일")
|
||||
ap.add_argument("--output", default=None, help="출력 JSON (기본: INPUT_merged.json)")
|
||||
ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
|
||||
ap.add_argument("--x-overlap", type=float, default=0.30, help="x-range 겹침 비율 임계값")
|
||||
ap.add_argument("--y-gap", type=int, default=150, help="y-range 간격 임계값 (px)")
|
||||
ap.add_argument("--label", default="catenary_pole", help="병합 대상 라벨명")
|
||||
args = ap.parse_args()
|
||||
|
||||
src = Path(args.input)
|
||||
if not src.exists():
|
||||
print(f"파일 없음: {src}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
data = json.loads(src.read_text(encoding="utf-8"))
|
||||
shapes = data.get("shapes", [])
|
||||
iW = data.get("imageWidth", 0)
|
||||
iH = data.get("imageHeight", 0)
|
||||
if iW == 0 or iH == 0:
|
||||
print("imageWidth/imageHeight 정보 없음", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
target = [s for s in shapes if s.get("label") == args.label]
|
||||
others = [s for s in shapes if s.get("label") != args.label]
|
||||
print(f"{args.label}: {len(target)}개 → 병합 처리")
|
||||
|
||||
merged = merge_poles(target, iH, iW, args.x_overlap, args.y_gap)
|
||||
data["shapes"] = others + merged
|
||||
|
||||
if args.inplace:
|
||||
dst = src
|
||||
elif args.output:
|
||||
dst = Path(args.output)
|
||||
else:
|
||||
dst = src.with_stem(src.stem + "_merged")
|
||||
|
||||
dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"저장: {dst}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
370
tools/rail_alignment_fit.py
Normal file
370
tools/rail_alignment_fit.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
철도 평면 선형 피팅 모듈
|
||||
스켈레톤 좌표점 → 직선/원곡선/완화곡선 분류 및 피팅 → 매끄러운 폴리라인
|
||||
|
||||
선형 구성요소:
|
||||
1. 직선 (Straight) : y = ax + b
|
||||
2. 원곡선 (Circular) : R = 상수, Taubin 원 피팅
|
||||
3. 완화곡선 (Transition) : y = x³/(6RL), 3차포물선 (KR 설계기준)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from scipy.signal import savgol_filter
|
||||
from scipy.linalg import eig as scipy_eig
|
||||
|
||||
|
||||
# ── 1. 곡률 계산 ──────────────────────────────────────────────────────────
|
||||
|
||||
def compute_curvature(pts, smooth_window=15):
|
||||
"""
|
||||
길이 가중 방향벡터를 이용한 곡률 계산
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pts : (N, 2) array 세계 좌표 (미터)
|
||||
smooth_window : int Savitzky-Golay 스무딩 윈도우 (홀수)
|
||||
|
||||
Returns
|
||||
-------
|
||||
curvatures : (N,) 곡률 κ [1/m]
|
||||
arc_lengths : (N,) 각 점의 누적 호 길이 [m]
|
||||
angles : (N,) 방향각 [rad], unwrap 처리됨
|
||||
"""
|
||||
pts = np.asarray(pts, dtype=float)
|
||||
n = len(pts)
|
||||
if n < 3:
|
||||
return np.zeros(n), np.zeros(n), np.zeros(n)
|
||||
|
||||
# 세그먼트 벡터 및 길이
|
||||
segs = np.diff(pts, axis=0) # (n-1, 2)
|
||||
lens = np.linalg.norm(segs, axis=1) # (n-1,)
|
||||
lens = np.maximum(lens, 1e-12)
|
||||
|
||||
# 단위 방향벡터
|
||||
dirs = segs / lens[:, np.newaxis]
|
||||
|
||||
# 가중 방향벡터: 인접 세그먼트 길이를 가중치로 평균
|
||||
directions = np.empty((n, 2))
|
||||
directions[0] = dirs[0]
|
||||
directions[-1] = dirs[-1]
|
||||
for i in range(1, n - 1):
|
||||
w1, w2 = lens[i - 1], lens[i]
|
||||
directions[i] = (w1 * dirs[i - 1] + w2 * dirs[i]) / (w1 + w2)
|
||||
|
||||
# 정규화
|
||||
norms = np.linalg.norm(directions, axis=1, keepdims=True)
|
||||
directions /= np.maximum(norms, 1e-12)
|
||||
|
||||
# 방향각 (연속성 보장)
|
||||
angles = np.arctan2(directions[:, 1], directions[:, 0])
|
||||
angles = np.unwrap(angles)
|
||||
|
||||
# 누적 호 길이
|
||||
arc_lengths = np.zeros(n)
|
||||
arc_lengths[1:] = np.cumsum(lens)
|
||||
|
||||
# 곡률 κ = dθ/ds (중앙차분)
|
||||
curvatures = np.zeros(n)
|
||||
for i in range(1, n - 1):
|
||||
dtheta = angles[i + 1] - angles[i - 1]
|
||||
ds = arc_lengths[i + 1] - arc_lengths[i - 1]
|
||||
if ds > 1e-10:
|
||||
curvatures[i] = dtheta / ds
|
||||
curvatures[0] = curvatures[1]
|
||||
curvatures[-1] = curvatures[-2]
|
||||
|
||||
# Savitzky-Golay 스무딩 (노이즈 억제)
|
||||
w = smooth_window
|
||||
if n >= w and w >= 5:
|
||||
w = w if w % 2 == 1 else w - 1
|
||||
curvatures = savgol_filter(curvatures, w, 3)
|
||||
|
||||
return curvatures, arc_lengths, angles
|
||||
|
||||
|
||||
# ── 2. 구간 분류 ──────────────────────────────────────────────────────────
|
||||
|
||||
def segment_by_curvature(curvatures, arc_lengths,
|
||||
kappa_straight=0.001, # 1/1000m → R > 1000m 직선 처리
|
||||
cv_thresh=0.35, # 변동계수(표준편차/평균) 임계값
|
||||
min_seg_len=3.0): # 최소 구간 길이 [m]
|
||||
"""
|
||||
곡률 패턴으로 구간 분류
|
||||
|
||||
Returns
|
||||
-------
|
||||
segments : list of (start_idx, end_idx, type_str)
|
||||
type_str ∈ {'straight', 'circular', 'transition'}
|
||||
"""
|
||||
n = len(curvatures)
|
||||
kappa_abs = np.abs(curvatures)
|
||||
win = max(5, n // 15)
|
||||
|
||||
def classify_window(k_arr):
|
||||
km = k_arr.mean()
|
||||
ks = k_arr.std()
|
||||
if km < kappa_straight:
|
||||
return 'straight'
|
||||
cv = ks / (km + 1e-12)
|
||||
if cv < cv_thresh:
|
||||
return 'circular'
|
||||
return 'transition'
|
||||
|
||||
# 각 점의 로컬 타입
|
||||
local_type = []
|
||||
for i in range(n):
|
||||
lo = max(0, i - win // 2)
|
||||
hi = min(n, i + win // 2 + 1)
|
||||
local_type.append(classify_window(kappa_abs[lo:hi]))
|
||||
|
||||
# 연속된 같은 타입 구간 합치기
|
||||
segments = []
|
||||
cur_type = local_type[0]
|
||||
cur_start = 0
|
||||
for i in range(1, n):
|
||||
if local_type[i] != cur_type:
|
||||
segments.append((cur_start, i - 1, cur_type))
|
||||
cur_type = local_type[i]
|
||||
cur_start = i
|
||||
segments.append((cur_start, n - 1, cur_type))
|
||||
|
||||
# 너무 짧은 구간은 인접 구간에 병합
|
||||
merged = True
|
||||
while merged and len(segments) > 1:
|
||||
merged = False
|
||||
new_segs = []
|
||||
i = 0
|
||||
while i < len(segments):
|
||||
s, e, t = segments[i]
|
||||
seg_len = arc_lengths[e] - arc_lengths[s]
|
||||
if seg_len < min_seg_len and len(segments) > 1:
|
||||
if i == 0 and i + 1 < len(segments):
|
||||
ns, ne, nt = segments[i + 1]
|
||||
segments[i + 1] = (s, ne, nt)
|
||||
i += 1
|
||||
merged = True
|
||||
elif new_segs:
|
||||
new_segs[-1] = (new_segs[-1][0], e, new_segs[-1][2])
|
||||
merged = True
|
||||
else:
|
||||
new_segs.append((s, e, t))
|
||||
else:
|
||||
new_segs.append((s, e, t))
|
||||
i += 1
|
||||
if merged:
|
||||
segments = new_segs
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
# ── 3. 개별 구간 피팅 ────────────────────────────────────────────────────
|
||||
|
||||
def fit_straight(pts, spacing=0.5):
|
||||
"""PCA 직선 피팅 → 균등 샘플 점 목록"""
|
||||
pts = np.asarray(pts, dtype=float)
|
||||
center = pts.mean(axis=0)
|
||||
_, _, Vt = np.linalg.svd(pts - center)
|
||||
direction = Vt[0]
|
||||
t = (pts - center) @ direction
|
||||
t_min, t_max = t.min(), t.max()
|
||||
n_pts = max(2, int((t_max - t_min) / spacing) + 1)
|
||||
ts = np.linspace(t_min, t_max, n_pts)
|
||||
return [(center + ti * direction).tolist() for ti in ts]
|
||||
|
||||
|
||||
def taubin_circle_fit(pts):
|
||||
"""Taubin 방법으로 안정적 원 피팅 → (cx, cy, R) 또는 None"""
|
||||
pts = np.asarray(pts, dtype=float)
|
||||
mean = pts.mean(axis=0)
|
||||
x = pts[:, 0] - mean[0]
|
||||
y = pts[:, 1] - mean[1]
|
||||
z = x**2 + y**2
|
||||
|
||||
Mxx = (x * x).mean()
|
||||
Myy = (y * y).mean()
|
||||
Mxy = (x * y).mean()
|
||||
Mxz = (x * z).mean()
|
||||
Myz = (y * z).mean()
|
||||
Mzz = (z * z).mean()
|
||||
|
||||
Mz = Mxx + Myy
|
||||
Cov_xy = Mxx * Myy - Mxy**2
|
||||
A3 = 4 * Mz
|
||||
A2 = -3 * Mz**2 - Mzz
|
||||
A1 = Mzz * Mz + 4 * Cov_xy * Mz - Mxz**2 - Myz**2 - Mz**3
|
||||
A0 = Mxz**2 * Myy + Myz**2 * Mxx - Mzz * Cov_xy - 2 * Mxz * Myz * Mxy + Mz**2 * Cov_xy
|
||||
|
||||
try:
|
||||
roots = np.roots([A3, A2, A1, A0])
|
||||
real_roots = roots[np.isreal(roots)].real
|
||||
if len(real_roots) == 0:
|
||||
return None
|
||||
xn = max(real_roots)
|
||||
yn = (xn**2 - Mz) * xn + A0 / A3
|
||||
det = xn**2 - xn * Mz + Cov_xy
|
||||
if abs(det) < 1e-12:
|
||||
return None
|
||||
xcen = (Mxz * (Myy - xn) - Myz * Mxy) / (2 * det) + mean[0]
|
||||
ycen = (Myz * (Mxx - xn) - Mxz * Mxy) / (2 * det) + mean[1]
|
||||
R = np.sqrt(np.mean((pts[:, 0] - xcen)**2 + (pts[:, 1] - ycen)**2))
|
||||
return xcen, ycen, R
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def fit_circular(pts, spacing=0.5):
|
||||
"""원곡선 피팅 → 호 위의 균등 샘플 점 목록"""
|
||||
result = taubin_circle_fit(pts)
|
||||
if result is None:
|
||||
return fit_straight(pts, spacing)
|
||||
|
||||
cx, cy, R = result
|
||||
pts = np.asarray(pts, dtype=float)
|
||||
angs = np.arctan2(pts[:, 1] - cy, pts[:, 0] - cx)
|
||||
angs = np.unwrap(angs)
|
||||
a_start, a_end = angs[0], angs[-1]
|
||||
|
||||
arc_len = abs(a_end - a_start) * R
|
||||
n_pts = max(2, int(arc_len / spacing) + 1)
|
||||
thetas = np.linspace(a_start, a_end, n_pts)
|
||||
return [(cx + R * np.cos(t), cy + R * np.sin(t)) for t in thetas]
|
||||
|
||||
|
||||
def fit_transition_cubic(pts, R_adj, spacing=0.5):
|
||||
"""
|
||||
3차포물선 완화곡선 피팅 y = x³ / (6·R·L)
|
||||
로컬 좌표계(시작점, 시작방향 기준)에서 피팅 후 세계 좌표로 역변환
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pts : (N, 2) 세계 좌표
|
||||
R_adj : 인접 원곡선 반경 [m] (없으면 곡률 역수로 추정)
|
||||
"""
|
||||
pts = np.asarray(pts, dtype=float)
|
||||
origin = pts[0].copy()
|
||||
|
||||
# 로컬 좌표계 축 (PCA 주방향)
|
||||
_, _, Vt = np.linalg.svd(pts - origin)
|
||||
axis_x = Vt[0]
|
||||
axis_y = np.array([-axis_x[1], axis_x[0]])
|
||||
|
||||
local = pts - origin
|
||||
lx = local @ axis_x
|
||||
ly = local @ axis_y
|
||||
|
||||
L = lx.max()
|
||||
if L < 1e-5 or R_adj < 1e-5:
|
||||
return pts.tolist()
|
||||
|
||||
# 균등 샘플
|
||||
n_pts = max(2, int(L / spacing) + 1)
|
||||
xs = np.linspace(0, L, n_pts)
|
||||
ys = xs**3 / (6.0 * R_adj * L)
|
||||
|
||||
return [(origin + xi * axis_x + yi * axis_y).tolist()
|
||||
for xi, yi in zip(xs, ys)]
|
||||
|
||||
|
||||
# ── 4. 전체 피팅 파이프라인 ───────────────────────────────────────────────
|
||||
|
||||
def fit_alignment(pts, spacing=0.5):
|
||||
"""
|
||||
메인 함수: 세계 좌표 점 목록 → 선형 피팅 → 매끄러운 폴리라인
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pts : list of (x, y) EPSG:5186 세계 좌표 [m]
|
||||
spacing : float 출력 점 간격 [m]
|
||||
|
||||
Returns
|
||||
-------
|
||||
smooth_pts : list of (x, y) 피팅된 매끄러운 폴리라인 점 목록
|
||||
seg_info : list of dict 각 구간 정보 (디버깅용)
|
||||
"""
|
||||
pts = np.asarray(pts, dtype=float)
|
||||
n = len(pts)
|
||||
|
||||
if n < 5:
|
||||
return pts.tolist(), []
|
||||
|
||||
# --- 곡률 계산 ---
|
||||
sw = min(21, n if n % 2 == 1 else n - 1)
|
||||
sw = max(5, sw)
|
||||
curvatures, arc_lengths, angles = compute_curvature(pts, sw)
|
||||
|
||||
# --- 구간 분류 ---
|
||||
segments = segment_by_curvature(curvatures, arc_lengths)
|
||||
|
||||
# --- 인접 원곡선 R 사전 계산 (완화곡선에서 사용) ---
|
||||
seg_R = {}
|
||||
for idx, (s, e, t) in enumerate(segments):
|
||||
if t == 'circular':
|
||||
km = np.abs(curvatures[s:e+1]).mean()
|
||||
seg_R[idx] = 1.0 / km if km > 1e-6 else 1000.0
|
||||
|
||||
smooth_pts = []
|
||||
seg_info = []
|
||||
|
||||
for idx, (s, e, seg_type) in enumerate(segments):
|
||||
seg_pts = pts[s:e+1]
|
||||
|
||||
if len(seg_pts) < 2:
|
||||
smooth_pts.extend(seg_pts.tolist())
|
||||
continue
|
||||
|
||||
km = np.abs(curvatures[s:e+1]).mean()
|
||||
R_est = 1.0 / km if km > 1e-6 else 9999.0
|
||||
|
||||
if seg_type == 'straight':
|
||||
fitted = fit_straight(seg_pts, spacing)
|
||||
elif seg_type == 'circular':
|
||||
fitted = fit_circular(seg_pts, spacing)
|
||||
else: # transition
|
||||
# 인접 원곡선 R 탐색
|
||||
R_use = R_est
|
||||
for di in [1, -1]:
|
||||
ni = idx + di
|
||||
if ni in seg_R:
|
||||
R_use = seg_R[ni]
|
||||
break
|
||||
fitted = fit_transition_cubic(seg_pts, R_use, spacing)
|
||||
|
||||
seg_len = arc_lengths[e] - arc_lengths[s]
|
||||
seg_info.append({
|
||||
'type': seg_type,
|
||||
'start_idx': s,
|
||||
'end_idx': e,
|
||||
'length_m': round(seg_len, 2),
|
||||
'R_m': round(R_est, 1),
|
||||
'n_pts_in': len(seg_pts),
|
||||
'n_pts_out': len(fitted),
|
||||
})
|
||||
|
||||
# 이음새 중복 제거
|
||||
if smooth_pts and fitted:
|
||||
smooth_pts.extend(fitted[1:])
|
||||
else:
|
||||
smooth_pts.extend(fitted)
|
||||
|
||||
return smooth_pts, seg_info
|
||||
|
||||
|
||||
# ── 5. 간단 테스트 ────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 직선 + 원곡선 + 완화곡선 합성 테스트
|
||||
t = np.linspace(0, 2*np.pi, 300)
|
||||
# 직선 구간
|
||||
straight = [(float(i)*0.5, 0.0) for i in range(40)]
|
||||
# 원곡선 구간 R=300m
|
||||
R = 300.0
|
||||
theta = np.linspace(0, np.pi/4, 60)
|
||||
cx, cy = straight[-1][0], straight[-1][1] + R
|
||||
circ = [(cx + R*np.sin(th), cy - R*np.cos(th)) for th in theta]
|
||||
pts_test = straight + circ[1:]
|
||||
|
||||
smooth, info = fit_alignment(pts_test, spacing=1.0)
|
||||
print(f"입력: {len(pts_test)}점 → 출력: {len(smooth)}점")
|
||||
for s in info:
|
||||
print(f" [{s['type']:10s}] L={s['length_m']:.1f}m R={s['R_m']:.0f}m "
|
||||
f"점: {s['n_pts_in']}→{s['n_pts_out']}")
|
||||
175
tools/rail_centerline_dxf.py
Normal file
175
tools/rail_centerline_dxf.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
2cm GSD 드론 TIF 전체에서 레일 중심선 추출 → DXF 저장
|
||||
|
||||
중심선 추출 방법:
|
||||
masks.xy 폴리곤 → PCA 주축(길이 방향) 찾기
|
||||
→ 주축 방향으로 슬라이싱 → 폭 방향 중앙값
|
||||
→ 계단 현상 없는 매끄러운 중심선
|
||||
"""
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.windows import Window
|
||||
from ultralytics import YOLO
|
||||
from PIL import Image
|
||||
import ezdxf
|
||||
|
||||
|
||||
def polygon_to_centerline(xy_tile, spacing_px=10):
|
||||
"""
|
||||
타일 픽셀 좌표 폴리곤 → 중심선 점 목록
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xy_tile : list of (x, y) 타일 픽셀 좌표 (masks.xy × sx/sy)
|
||||
spacing_px : int 슬라이싱 간격 [픽셀] (10px = 0.2m at 2cm GSD)
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of (x, y) 중심선 타일 픽셀 좌표
|
||||
"""
|
||||
pts = np.asarray(xy_tile, dtype=float)
|
||||
if len(pts) < 4:
|
||||
return []
|
||||
|
||||
# PCA: 주축(레일 길이 방향) 탐색
|
||||
center = pts.mean(axis=0)
|
||||
_, _, Vt = np.linalg.svd(pts - center, full_matrices=False)
|
||||
v_long = Vt[0] # 길이 방향 단위벡터
|
||||
v_perp = Vt[1] # 폭 방향 단위벡터
|
||||
|
||||
local = pts - center
|
||||
t = local @ v_long # 길이 방향 좌표
|
||||
|
||||
t_min, t_max = t.min(), t.max()
|
||||
if t_max - t_min < spacing_px:
|
||||
# 너무 짧으면 무게중심 1점
|
||||
return [center.tolist()]
|
||||
|
||||
n_bins = max(2, int((t_max - t_min) / spacing_px) + 1)
|
||||
bins = np.linspace(t_min, t_max, n_bins + 1)
|
||||
|
||||
centerline = []
|
||||
s_all = local @ v_perp # 폭 방향 좌표 (전체)
|
||||
|
||||
for i in range(n_bins):
|
||||
mask = (t >= bins[i]) & (t <= bins[i + 1])
|
||||
if mask.sum() < 2:
|
||||
continue
|
||||
t_c = (bins[i] + bins[i + 1]) / 2
|
||||
s_vals = s_all[mask]
|
||||
# 폭 방향: 폴리곤 양 가장자리의 기하학적 중앙
|
||||
s_c = (s_vals.max() + s_vals.min()) / 2
|
||||
pt = center + t_c * v_long + s_c * v_perp
|
||||
centerline.append(pt.tolist())
|
||||
|
||||
return centerline
|
||||
|
||||
|
||||
# ── 설정 ──────────────────────────────────────────
|
||||
SRC_TIF = "drone_2cm/22)조치원(STA.127+570~131+300).tif"
|
||||
MODEL_PT = "runs/segment/output/yolo_train_2cm/rail_seg_v2/weights/best.pt"
|
||||
OUT_DXF = "output/rail_centerline_2cm_pca.dxf"
|
||||
TILE = 1024
|
||||
OVERLAP = 128
|
||||
STEP = TILE - OVERLAP
|
||||
CONF = 0.3
|
||||
BLACK_THR = 0.4 # 검은 픽셀 비율 임계값
|
||||
SPACING = 10 # 중심선 샘플 간격 [픽셀] 10px ≈ 0.2m
|
||||
|
||||
# ── 모델 로드 ──────────────────────────────────────
|
||||
print("모델 로드 중...")
|
||||
model = YOLO(MODEL_PT)
|
||||
|
||||
# ── DXF 초기화 ────────────────────────────────────
|
||||
doc = ezdxf.new()
|
||||
msp = doc.modelspace()
|
||||
doc.layers.add("RAIL_CENTERLINE", color=3)
|
||||
|
||||
# ── TIF 열기 ──────────────────────────────────────
|
||||
src = rasterio.open(SRC_TIF)
|
||||
W, H = src.width, src.height
|
||||
transform = src.transform
|
||||
print(f"이미지 크기: {W} x {H}")
|
||||
|
||||
|
||||
def pixel_to_world(px, py):
|
||||
"""픽셀 좌표 → 세계 좌표 (EPSG:5186)"""
|
||||
wx = transform.c + px * transform.a
|
||||
wy = transform.f + py * transform.e
|
||||
return wx, wy
|
||||
|
||||
|
||||
# ── 타일 순회 ─────────────────────────────────────
|
||||
total_polylines = 0
|
||||
xs_range = range(0, W - TILE // 2, STEP)
|
||||
ys_range = range(0, H - TILE // 2, STEP)
|
||||
total_tiles = len(xs_range) * len(ys_range)
|
||||
processed = 0
|
||||
|
||||
print(f"총 타일 수(예상): {total_tiles}, 처리 시작...")
|
||||
|
||||
for ty in ys_range:
|
||||
for tx in xs_range:
|
||||
tw = min(TILE, W - tx)
|
||||
th = min(TILE, H - ty)
|
||||
if tw < 64 or th < 64:
|
||||
continue
|
||||
|
||||
win = Window(tx, ty, tw, th)
|
||||
data = src.read([1, 2, 3], window=win)
|
||||
img_arr = np.transpose(data, (1, 2, 0)).astype(np.uint8)
|
||||
|
||||
black_ratio = np.mean(np.all(img_arr < 10, axis=2))
|
||||
if black_ratio > BLACK_THR:
|
||||
processed += 1
|
||||
continue
|
||||
|
||||
pil_img = Image.fromarray(img_arr)
|
||||
results = model.predict(pil_img, imgsz=TILE, conf=CONF,
|
||||
device='cuda:0', verbose=False)
|
||||
|
||||
if not results or results[0].masks is None:
|
||||
processed += 1
|
||||
continue
|
||||
|
||||
masks_data = results[0].masks.data.cpu().numpy()
|
||||
h_r, w_r = masks_data.shape[1], masks_data.shape[2]
|
||||
sx = tw / w_r
|
||||
sy = th / h_r
|
||||
|
||||
for xy in results[0].masks.xy:
|
||||
if len(xy) < 4:
|
||||
continue
|
||||
|
||||
# YOLO 좌표 → 타일 픽셀 좌표
|
||||
xy_tile = [(float(lx) * sx, float(ly) * sy) for lx, ly in xy]
|
||||
|
||||
# PCA 슬라이싱으로 중심선 추출 (타일 픽셀 좌표)
|
||||
cl_tile = polygon_to_centerline(xy_tile, spacing_px=SPACING)
|
||||
if len(cl_tile) < 2:
|
||||
continue
|
||||
|
||||
# overlap 가장자리 제거 + 세계 좌표 변환
|
||||
pts_world = []
|
||||
for lx, ly in cl_tile:
|
||||
if tx + tw < W and lx > (tw - OVERLAP // 2):
|
||||
continue
|
||||
if ty + th < H and ly > (th - OVERLAP // 2):
|
||||
continue
|
||||
pts_world.append(pixel_to_world(tx + lx, ty + ly))
|
||||
|
||||
if len(pts_world) >= 2:
|
||||
msp.add_lwpolyline(pts_world,
|
||||
dxfattribs={"layer": "RAIL_CENTERLINE"})
|
||||
total_polylines += 1
|
||||
|
||||
processed += 1
|
||||
if processed % 500 == 0:
|
||||
pct = processed / total_tiles * 100
|
||||
print(f" 진행: {processed}/{total_tiles} ({pct:.1f}%)"
|
||||
f" | 폴리라인: {total_polylines}")
|
||||
|
||||
src.close()
|
||||
|
||||
doc.saveas(OUT_DXF)
|
||||
print(f"\n완료: {total_polylines}개 폴리라인 → {OUT_DXF}")
|
||||
234
tools/rail_to_dxf.py
Normal file
234
tools/rail_to_dxf.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
rail_to_dxf.py
|
||||
==============
|
||||
X-AnyLabeling JSON 어노테이션에서 레일 중심선을 추출하여 Rhino용 DXF로 저장.
|
||||
|
||||
사용법:
|
||||
python tools/rail_to_dxf.py <annotation.json> [output.dxf]
|
||||
|
||||
예시:
|
||||
python tools/rail_to_dxf.py images/rail.json
|
||||
python tools/rail_to_dxf.py images/rail.json output/rail_centerline.dxf
|
||||
|
||||
라벨 이름 (X-AnyLabeling에서 Finish 후 입력한 이름):
|
||||
rail, railway_track, track, AUTOLABEL_OBJECT 자동 인식
|
||||
다른 이름이면 스크립트 하단 TARGET_LABELS 수정
|
||||
|
||||
Rhino에서 사용:
|
||||
1. DXF Import
|
||||
2. RAIL_CENTERLINE 레이어 선택
|
||||
3. Sweep2 또는 Rail Sweep으로 레일 단면 적용
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from skimage.morphology import skeletonize
|
||||
|
||||
|
||||
# ─── 설정 ─────────────────────────────────────────────────────────────────────
|
||||
TARGET_LABELS = [
|
||||
"rail",
|
||||
"railline",
|
||||
"railway_track",
|
||||
"track",
|
||||
"레일",
|
||||
"철로",
|
||||
"AUTOLABEL_OBJECT",
|
||||
]
|
||||
# line/linestrip 타입은 스켈레톤 불필요 — 직접 DXF 출력
|
||||
LINE_SHAPE_TYPES = {"line", "linestrip", "lines"}
|
||||
SMOOTH_WINDOW = 15 # 중심선 스무딩 강도 (클수록 부드러움, 0=비활성)
|
||||
DOWNSAMPLE_STEP = 8 # Rhino 폴리라인 포인트 간격 (클수록 포인트 수 감소)
|
||||
DILATION_ITER = 3 # 마스크 팽창 반복 (얇은 레일 마스크 연결 보완)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def polygon_to_mask(points, h, w):
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
pts = np.array([[int(p[0]), int(p[1])] for p in points], dtype=np.int32)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
return mask
|
||||
|
||||
|
||||
def extract_skeleton(mask):
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
dilated = cv2.dilate(mask, kernel, iterations=DILATION_ITER)
|
||||
return skeletonize(dilated > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def order_skeleton(skeleton):
|
||||
"""스켈레톤 픽셀을 끝점에서 시작해 순서대로 연결."""
|
||||
ys, xs = np.where(skeleton > 0)
|
||||
if len(xs) == 0:
|
||||
return []
|
||||
|
||||
pt_set = set(zip(xs.tolist(), ys.tolist()))
|
||||
|
||||
def neighbors(pt):
|
||||
x, y = pt
|
||||
return [(x+dx, y+dy)
|
||||
for dx in (-1,0,1) for dy in (-1,0,1)
|
||||
if not (dx==0 and dy==0) and (x+dx, y+dy) in pt_set]
|
||||
|
||||
# 끝점(이웃 1개) 찾기 → 없으면 임의 시작
|
||||
endpoints = [p for p in pt_set if len(neighbors(p)) == 1]
|
||||
start = endpoints[0] if endpoints else next(iter(pt_set))
|
||||
|
||||
ordered, visited = [start], {start}
|
||||
current = start
|
||||
|
||||
while True:
|
||||
nbs = [n for n in neighbors(current) if n not in visited]
|
||||
if not nbs:
|
||||
break
|
||||
# 가장 직선에 가까운 방향 우선 선택
|
||||
if len(ordered) >= 2:
|
||||
dx = current[0] - ordered[-2][0]
|
||||
dy = current[1] - ordered[-2][1]
|
||||
nbs.sort(key=lambda n: -((n[0]-current[0])*dx + (n[1]-current[1])*dy))
|
||||
current = nbs[0]
|
||||
visited.add(current)
|
||||
ordered.append(current)
|
||||
|
||||
return ordered
|
||||
|
||||
|
||||
def smooth_polyline(points, window):
|
||||
if window < 3 or len(points) < window:
|
||||
return points
|
||||
pts = np.array(points, dtype=float)
|
||||
half = window // 2
|
||||
out = pts.copy()
|
||||
for i in range(half, len(pts) - half):
|
||||
out[i] = pts[i-half:i+half+1].mean(axis=0)
|
||||
return out.tolist()
|
||||
|
||||
|
||||
def downsample(points, step):
|
||||
if step <= 1 or len(points) <= step:
|
||||
return points
|
||||
sampled = points[::step]
|
||||
if list(sampled[-1]) != list(points[-1]):
|
||||
sampled = list(sampled) + [points[-1]]
|
||||
return sampled
|
||||
|
||||
|
||||
def to_dxf_coords(points, flip_y=True):
|
||||
"""이미지 좌표(Y↓) → DXF 좌표(Y↑)"""
|
||||
if flip_y:
|
||||
return [(float(p[0]), float(-p[1])) for p in points]
|
||||
return [(float(p[0]), float(p[1])) for p in points]
|
||||
|
||||
|
||||
def process(json_path: str, dxf_path: str):
|
||||
import ezdxf
|
||||
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
H = data.get("imageHeight", 1000)
|
||||
W = data.get("imageWidth", 1000)
|
||||
print(f"[이미지] {W} × {H} px")
|
||||
|
||||
shapes = data.get("shapes", [])
|
||||
targets = [s for s in shapes if s.get("label") in TARGET_LABELS]
|
||||
|
||||
if not targets:
|
||||
print("⚠ TARGET_LABELS 일치 없음 → 모든 shape 사용")
|
||||
targets = shapes
|
||||
|
||||
type_counts = {}
|
||||
for s in targets:
|
||||
t = s.get("shape_type", "unknown")
|
||||
type_counts[t] = type_counts.get(t, 0) + 1
|
||||
print(f"[처리] {len(targets)}개: {type_counts}")
|
||||
|
||||
doc = ezdxf.new("R2010")
|
||||
msp = doc.modelspace()
|
||||
doc.layers.add("RAIL_CENTERLINE", color=1) # 빨강 — Sweep 경로
|
||||
doc.layers.add("RAIL_POLYGON", color=3) # 초록 — 원본 마스크 윤곽
|
||||
|
||||
for i, shape in enumerate(targets):
|
||||
label = shape.get("label", f"shape_{i}")
|
||||
pts = shape.get("points", [])
|
||||
shape_type = shape.get("shape_type", "polygon")
|
||||
print(f"\n [{i+1}] label={label!r} type={shape_type!r} pts={len(pts)}")
|
||||
|
||||
if len(pts) < 2:
|
||||
print(" → 포인트 부족, 건너뜀")
|
||||
continue
|
||||
|
||||
# 중복 끝점 제거 (Polygon 도구로 그린 선: 마지막 점이 앞 점과 거의 동일)
|
||||
if len(pts) >= 3:
|
||||
dedup = [pts[0]]
|
||||
for p in pts[1:]:
|
||||
if abs(p[0]-dedup[-1][0]) > 2 or abs(p[1]-dedup[-1][1]) > 2:
|
||||
dedup.append(p)
|
||||
if len(dedup) < len(pts):
|
||||
print(f" 중복점 제거: {len(pts)}pt → {len(dedup)}pt")
|
||||
pts = dedup
|
||||
|
||||
# ── LINE 타입: 스켈레톤 없이 직접 DXF 출력 ──────
|
||||
if shape_type in LINE_SHAPE_TYPES or len(pts) == 2:
|
||||
cl_dxf = to_dxf_coords(pts)
|
||||
msp.add_lwpolyline(cl_dxf, dxfattribs={"layer": "RAIL_CENTERLINE"})
|
||||
print(f" → 라인 직접 출력 ({len(pts)}pt)")
|
||||
continue
|
||||
|
||||
# ── POLYGON 타입: 마스크 → 스켈레톤 → 중심선 ────
|
||||
if len(pts) < 3:
|
||||
print(" → 폴리곤 포인트 부족, 건너뜀")
|
||||
continue
|
||||
|
||||
poly_dxf = to_dxf_coords(pts)
|
||||
poly_dxf.append(poly_dxf[0])
|
||||
msp.add_lwpolyline(poly_dxf, dxfattribs={"layer": "RAIL_POLYGON"})
|
||||
|
||||
mask = polygon_to_mask(pts, H, W)
|
||||
area = int(mask.sum())
|
||||
print(f" 마스크 면적: {area} px²")
|
||||
if area < 20:
|
||||
print(" → 면적 너무 작음, 건너뜀")
|
||||
continue
|
||||
|
||||
skel = extract_skeleton(mask)
|
||||
skel_n = int(skel.sum())
|
||||
print(f" 스켈레톤 픽셀: {skel_n}")
|
||||
if skel_n < 2:
|
||||
print(" → 스켈레톤 생성 실패, 건너뜀")
|
||||
continue
|
||||
|
||||
ordered = order_skeleton(skel)
|
||||
smoothed = smooth_polyline(ordered, SMOOTH_WINDOW)
|
||||
final_pts = downsample(smoothed, DOWNSAMPLE_STEP)
|
||||
print(f" 중심선 포인트: {len(final_pts)}")
|
||||
if len(final_pts) < 2:
|
||||
continue
|
||||
|
||||
cl_dxf = to_dxf_coords(final_pts)
|
||||
msp.add_lwpolyline(cl_dxf, dxfattribs={"layer": "RAIL_CENTERLINE"})
|
||||
|
||||
Path(dxf_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.saveas(dxf_path)
|
||||
print(f"\n[완료] DXF 저장: {dxf_path}")
|
||||
print(" 레이어: RAIL_CENTERLINE(빨강) = Sweep 경로")
|
||||
print(" RAIL_POLYGON(초록) = 원본 마스크")
|
||||
print("\nRhino 사용법:")
|
||||
print(" 1. File → Import → DXF 선택")
|
||||
print(" 2. RAIL_CENTERLINE 레이어 선택")
|
||||
print(" 3. Rail 단면 커브 그리기 (레일 단면 약 60x60mm)")
|
||||
print(" 4. Sweep1 또는 Sweep2 명령으로 단면 × 중심선 → 3D 레일 생성")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("사용법: python tools/rail_to_dxf.py <annotation.json> [output.dxf]")
|
||||
sys.exit(1)
|
||||
|
||||
json_file = sys.argv[1]
|
||||
dxf_file = sys.argv[2] if len(sys.argv) >= 3 else str(Path(json_file).with_suffix(".dxf"))
|
||||
|
||||
process(json_file, dxf_file)
|
||||
870
tools/railway_pipeline.py
Normal file
870
tools/railway_pipeline.py
Normal file
@@ -0,0 +1,870 @@
|
||||
"""
|
||||
railway_pipeline.py
|
||||
===================
|
||||
정사영상(GeoTIFF/PNG)에서 철도 시설물을 자동 검출하여
|
||||
실좌표(UTM/WGS84) 기반 DXF + GeoJSON 출력.
|
||||
|
||||
이미지에 실제로 보이는 것을 그대로 검출 (표준 규격 기반 아님).
|
||||
|
||||
사용법:
|
||||
python tools/railway_pipeline.py <image.tif|png> [output_dir]
|
||||
|
||||
예시:
|
||||
python tools/railway_pipeline.py 경부선.tif output/
|
||||
python tools/railway_pipeline.py 경부선.png output/ # PNG는 GSD=0.05m 가정
|
||||
|
||||
출력:
|
||||
output/railway_rails.dxf - 레일 중심선
|
||||
output/railway_objects.dxf - 전체 시설물 (레이어별)
|
||||
output/railway_objects.geojson - GIS 활용용
|
||||
output/railway_debug.jpg - 검출 결과 시각화
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
# ─── 검출 파라미터 ─────────────────────────────────────────────────────────────
|
||||
RAIL = dict(
|
||||
canny_low=30, canny_high=80,
|
||||
hough_threshold=400, min_len=500, max_gap=30,
|
||||
cluster_dist=14, min_total_len=3000,
|
||||
)
|
||||
SLEEPER = dict(
|
||||
search_width=60, # 레일 양옆 검색 폭 (px)
|
||||
min_area=30, # 최소 면적
|
||||
max_area=2000, # 최대 면적
|
||||
aspect_min=2.0, # 최소 가로세로비 (침목은 길쭉)
|
||||
)
|
||||
POLE = dict(
|
||||
search_margin=200, # 레일 옆 검색 범위 (px)
|
||||
min_radius=3, # 최소 반지름 (px)
|
||||
max_radius=25, # 최대 반지름 (px)
|
||||
min_dist=40, # 검출 최소 간격 (px)
|
||||
param1=50, param2=25, # Hough Circle 파라미터
|
||||
)
|
||||
CBOX = dict(
|
||||
search_margin=300, # 레일 옆 검색 범위 (px)
|
||||
min_area=100,
|
||||
max_area=8000,
|
||||
aspect_min=1.2, # 직사각형 비율
|
||||
aspect_max=6.0,
|
||||
min_solidity=0.75, # 컨투어 충실도 (직사각형에 가까울수록 1)
|
||||
)
|
||||
GABOR = dict(
|
||||
lambdas=[5, 7, 10], # 침목 간격 범위 (px, 축소 이미지 기준)
|
||||
# 5cm GSD, 8000px 폭(scale≈0.54): 600mm/50mm*0.54 ≈ 6.5px
|
||||
sigma=3.0, # Gabor 커널 폭
|
||||
gamma=0.5, # 종횡비 (1=원형, 0.5=타원)
|
||||
n_angles=12, # 방향 수 (0°~165°, 15° 간격)
|
||||
thresh_pct=60, # 응답 임계값 백분위 (높을수록 엄격)
|
||||
min_area=4000, # 최소 연결성분 면적 (작은 노이즈 제거)
|
||||
)
|
||||
SCALE_FACTOR = 8000 # 침목 패턴 감지를 위해 높은 해상도 유지
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 좌표 변환
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class GeoTransform:
|
||||
"""픽셀좌표 ↔ 실좌표 변환."""
|
||||
|
||||
def __init__(self, image_path: str):
|
||||
self.crs = None
|
||||
self.affine = None
|
||||
self.origin = (0.0, 0.0)
|
||||
self.pixel_size = (0.05, 0.05) # 기본 5cm GSD
|
||||
self._load(image_path)
|
||||
|
||||
def _load(self, path: str):
|
||||
if path.lower().endswith(('.tif', '.tiff')):
|
||||
try:
|
||||
import rasterio
|
||||
with rasterio.open(path) as r:
|
||||
self.crs = str(r.crs)
|
||||
t = r.transform
|
||||
self.affine = t
|
||||
self.pixel_size = (abs(t.a), abs(t.e))
|
||||
self.origin = (t.c, t.f)
|
||||
print(f" CRS: {self.crs}")
|
||||
print(f" 픽셀크기: {self.pixel_size[0]:.4f}m x {self.pixel_size[1]:.4f}m")
|
||||
print(f" 원점: ({self.origin[0]:.2f}, {self.origin[1]:.2f})")
|
||||
except Exception as e:
|
||||
print(f" [경고] rasterio 실패: {e} → 픽셀좌표 사용")
|
||||
else:
|
||||
# PNG: world file 탐색
|
||||
wld = Path(path).with_suffix('.pgw')
|
||||
if not wld.exists():
|
||||
wld = Path(path).with_suffix('.wld')
|
||||
if wld.exists():
|
||||
vals = [float(l.strip()) for l in wld.read_text().splitlines() if l.strip()]
|
||||
if len(vals) >= 6:
|
||||
self.pixel_size = (abs(vals[0]), abs(vals[3]))
|
||||
self.origin = (vals[4], vals[5])
|
||||
print(f" World file 적용: origin={self.origin}")
|
||||
else:
|
||||
print(f" World file 없음 → 픽셀좌표 사용 (GSD=0.05m)")
|
||||
|
||||
def px_to_world(self, px: float, py: float):
|
||||
"""픽셀 (col, row) → 실좌표 (x, y)."""
|
||||
if self.affine:
|
||||
from rasterio.transform import xy as rio_xy
|
||||
try:
|
||||
x, y = self.affine * (px, py)
|
||||
return float(x), float(y)
|
||||
except Exception:
|
||||
pass
|
||||
x = self.origin[0] + px * self.pixel_size[0]
|
||||
y = self.origin[1] - py * self.pixel_size[1]
|
||||
return x, y
|
||||
|
||||
def scale_coords(self, pts, scale: float):
|
||||
"""축소된 이미지의 픽셀좌표를 원본 픽셀좌표로 변환 후 실좌표 출력."""
|
||||
return [self.px_to_world(p[0] / scale, p[1] / scale) for p in pts]
|
||||
|
||||
def scale_point(self, px: float, py: float, scale: float):
|
||||
return self.px_to_world(px / scale, py / scale)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 이미지 로드 (한글 경로 대응)
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def load_image(path: str):
|
||||
"""TIF/PNG 모두 BGR numpy array로 반환."""
|
||||
if path.lower().endswith(('.tif', '.tiff')):
|
||||
try:
|
||||
import rasterio
|
||||
with rasterio.open(path) as r:
|
||||
if r.count >= 3:
|
||||
arr = np.dstack([r.read(i) for i in [1, 2, 3]])
|
||||
# uint16 → uint8 변환
|
||||
if arr.dtype != np.uint8:
|
||||
arr = (arr / arr.max() * 255).astype(np.uint8)
|
||||
return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
band = r.read(1)
|
||||
if band.dtype != np.uint8:
|
||||
band = (band / band.max() * 255).astype(np.uint8)
|
||||
return cv2.cvtColor(band, cv2.COLOR_GRAY2BGR)
|
||||
except Exception as e:
|
||||
print(f" rasterio 읽기 실패: {e}, PIL 시도...")
|
||||
# PNG 또는 fallback
|
||||
with open(path, 'rb') as f:
|
||||
data = f.read()
|
||||
arr = np.frombuffer(data, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
return img
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 자갈도상 마스크
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def extract_ballast_mask(img):
|
||||
"""자갈도상(ballast) 구역 마스크 추출.
|
||||
항공 정사영상에서 자갈도상은 밝고(V>70) 채도 낮음(S<60) — HSV 임계값 후
|
||||
큰 연결성분만 유지하여 선로 띠 영역만 남김.
|
||||
"""
|
||||
H, W = img.shape[:2]
|
||||
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# 낮은 채도 + 중간-높은 밝기 = 자갈/회색 (채도 45 이하로 더 엄격하게)
|
||||
mask = cv2.inRange(hsv,
|
||||
np.array([0, 0, 80], np.uint8),
|
||||
np.array([180, 45, 210], np.uint8))
|
||||
|
||||
# 작은 노이즈 제거
|
||||
k_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k_open, iterations=2)
|
||||
|
||||
# 팽창: 자갈 구역 연결 + 레일 엣지까지 포함 (2회로 줄여 과팽창 방지)
|
||||
k_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
|
||||
mask = cv2.dilate(mask, k_dil, iterations=2)
|
||||
|
||||
# 연결성분 중 이미지 면적 0.5% 이상만 유지 (도로·건물 등 소규모 제거)
|
||||
min_area = max(H * W * 0.005, 5000)
|
||||
n, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
|
||||
result = np.zeros_like(mask)
|
||||
for lbl in range(1, n):
|
||||
if stats[lbl, cv2.CC_STAT_AREA] >= min_area:
|
||||
result[labels == lbl] = 255
|
||||
|
||||
kept = int((result > 0).sum())
|
||||
total = H * W
|
||||
print(f" 자갈도상 마스크: {kept:,}px ({kept/total*100:.1f}%)")
|
||||
return result
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Gabor 침목 패턴 마스크 + 레일 검출 (통합)
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def detect_rails_gabor(img):
|
||||
"""
|
||||
[1단계] Gabor 필터 뱅크로 침목 줄무늬 패턴 영역 추출
|
||||
→ 자갈도상 색상이 아닌 침목 주기 텍스처로 선로 구역 판별
|
||||
→ 도로는 이 패턴 없으므로 자동 제거
|
||||
|
||||
[2단계] 침목 마스크 안에서 Hough로 개별 레일 중심선 검출
|
||||
|
||||
Returns: (rail_polylines, gabor_mask)
|
||||
"""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
H, W = gray.shape
|
||||
|
||||
# ── Gabor 필터 뱅크 적용 ──────────────────────────────────
|
||||
response_max = np.zeros((H, W), dtype=np.float32)
|
||||
n_ang = GABOR['n_angles']
|
||||
for i in range(n_ang):
|
||||
theta = i * np.pi / n_ang # 0 ~ π (모든 방향)
|
||||
for lam in GABOR['lambdas']:
|
||||
ksize = max(int(GABOR['sigma'] * 6) | 1, 7) # 홀수 보장
|
||||
kernel = cv2.getGaborKernel(
|
||||
(ksize, ksize),
|
||||
GABOR['sigma'], theta, lam,
|
||||
GABOR['gamma'], 0, cv2.CV_32F
|
||||
)
|
||||
resp = np.abs(cv2.filter2D(gray.astype(np.float32), cv2.CV_32F, kernel))
|
||||
response_max = np.maximum(response_max, resp)
|
||||
|
||||
# ── 응답 정규화 + 임계값 ──────────────────────────────────
|
||||
resp_u8 = cv2.normalize(response_max, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
|
||||
nz = resp_u8[resp_u8 > 0]
|
||||
thresh_val = int(np.percentile(nz, GABOR['thresh_pct'])) if len(nz) else 128
|
||||
_, binary = cv2.threshold(resp_u8, thresh_val, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# ── 모폴로지 정리 ─────────────────────────────────────────
|
||||
k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
||||
binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, k5, iterations=2)
|
||||
binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, k3, iterations=1)
|
||||
|
||||
# ── 큰 연결성분만 유지 (소규모 노이즈 제거) ───────────────
|
||||
n_cc, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
||||
gabor_mask = np.zeros_like(binary)
|
||||
for lbl in range(1, n_cc):
|
||||
if stats[lbl, cv2.CC_STAT_AREA] >= GABOR['min_area']:
|
||||
gabor_mask[labels == lbl] = 255
|
||||
|
||||
kept = int((gabor_mask > 0).sum())
|
||||
print(f" Gabor 마스크: {kept:,}px ({kept / (H * W) * 100:.1f}%)")
|
||||
|
||||
# ── Gabor 마스크 내에서 Hough 레일 중심선 검출 ───────────
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(gray)
|
||||
blurred = cv2.GaussianBlur(enhanced, (5, 5), 0)
|
||||
edges = cv2.Canny(blurred, RAIL['canny_low'], RAIL['canny_high'])
|
||||
edges = cv2.bitwise_and(edges, edges, mask=gabor_mask)
|
||||
print(f" 마스크 내 엣지: {int(edges.sum() // 255):,}px")
|
||||
|
||||
lines_raw = cv2.HoughLinesP(
|
||||
edges, 1, np.pi / 180,
|
||||
threshold=RAIL['hough_threshold'],
|
||||
minLineLength=RAIL['min_len'],
|
||||
maxLineGap=RAIL['max_gap'],
|
||||
)
|
||||
if lines_raw is None:
|
||||
print(" Hough 검출 없음")
|
||||
return [], gabor_mask
|
||||
|
||||
lines = [tuple(l[0]) for l in lines_raw]
|
||||
print(f" Hough 검출: {len(lines)}개")
|
||||
|
||||
# ── 클러스터링 → 폴리라인 ────────────────────────────────
|
||||
# 곡선 구간 지원: 각도 기준 아닌 미드포인트 거리 기준으로 클러스터링
|
||||
# 같은 레일의 곡선 세그먼트들은 인접 미드포인트를 가짐 → 체인으로 연결
|
||||
ATOL = RAIL['cluster_dist']
|
||||
|
||||
# 각도별 그룹화 후 수직거리 기반 클러스터링
|
||||
# 분기기 구간: 강제 연결 없이 세그먼트 단위로 유지 (분기기는 수동 연결)
|
||||
def angle(x1, y1, x2, y2):
|
||||
return np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180
|
||||
|
||||
def perp_dist(x1, y1, x2, y2, px, py):
|
||||
dx, dy = x2 - x1, y2 - y1
|
||||
L = math.hypot(dx, dy)
|
||||
if L == 0:
|
||||
return math.hypot(px - x1, py - y1)
|
||||
return abs(dy * px - dx * py + x2 * y1 - y2 * x1) / L
|
||||
|
||||
angle_groups = defaultdict(list)
|
||||
for l in lines:
|
||||
key = round(angle(*l) / 5) * 5
|
||||
angle_groups[key].append(l)
|
||||
|
||||
clusters = []
|
||||
for grp in angle_groups.values():
|
||||
used = [False] * len(grp)
|
||||
for i, li in enumerate(grp):
|
||||
if used[i]:
|
||||
continue
|
||||
cl = [li]
|
||||
used[i] = True
|
||||
for j, lj in enumerate(grp):
|
||||
if used[j]:
|
||||
continue
|
||||
mxj = ((lj[0]+lj[2])/2, (lj[1]+lj[3])/2)
|
||||
if perp_dist(*li, *mxj) < ATOL:
|
||||
cl.append(lj)
|
||||
used[j] = True
|
||||
clusters.append(cl)
|
||||
|
||||
def merge(cluster):
|
||||
pts = []
|
||||
for x1, y1, x2, y2 in cluster:
|
||||
pts += [(x1, y1), (x2, y2)]
|
||||
arr = np.array(pts, float)
|
||||
mean = arr.mean(0)
|
||||
cov = np.cov((arr - mean).T)
|
||||
if cov.ndim < 2:
|
||||
direction = np.array([1.0, 0.0])
|
||||
else:
|
||||
ev, evec = np.linalg.eig(cov)
|
||||
direction = evec[:, np.argmax(ev)]
|
||||
proj = (arr - mean).dot(direction)
|
||||
sorted_pts = arr[np.argsort(proj)]
|
||||
merged = [sorted_pts[0]]
|
||||
for p in sorted_pts[1:]:
|
||||
if math.hypot(p[0] - merged[-1][0], p[1] - merged[-1][1]) > 5:
|
||||
merged.append(p)
|
||||
return merged
|
||||
|
||||
def poly_len(poly):
|
||||
return sum(
|
||||
math.hypot(poly[i + 1][0] - poly[i][0], poly[i + 1][1] - poly[i][1])
|
||||
for i in range(len(poly) - 1)
|
||||
)
|
||||
|
||||
result = []
|
||||
for cl in clusters:
|
||||
poly = merge(cl)
|
||||
if poly_len(poly) >= RAIL['min_total_len']:
|
||||
result.append(poly)
|
||||
result.sort(key=lambda p: -poly_len(p))
|
||||
return result, gabor_mask
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 레일 검출 (Hough — fallback용, 주 검출은 detect_rails_gabor 사용)
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def detect_rails(img, ballast_mask=None):
|
||||
"""Hough Line으로 레일 중심선 폴리라인 반환."""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
|
||||
enhanced = clahe.apply(gray)
|
||||
blurred = cv2.GaussianBlur(enhanced, (5,5), 0)
|
||||
edges = cv2.Canny(blurred, RAIL['canny_low'], RAIL['canny_high'])
|
||||
|
||||
# 자갈도상 마스크 적용 — 선로 구역 외 엣지 제거
|
||||
if ballast_mask is not None:
|
||||
edges = cv2.bitwise_and(edges, edges, mask=ballast_mask)
|
||||
print(f" 마스크 적용 후 엣지: {int(edges.sum()//255):,}px")
|
||||
|
||||
lines_raw = cv2.HoughLinesP(edges, 1, np.pi/180,
|
||||
threshold=RAIL['hough_threshold'],
|
||||
minLineLength=RAIL['min_len'],
|
||||
maxLineGap=RAIL['max_gap'])
|
||||
if lines_raw is None:
|
||||
return []
|
||||
|
||||
lines = [tuple(l[0]) for l in lines_raw]
|
||||
|
||||
def angle(x1,y1,x2,y2):
|
||||
return np.degrees(np.arctan2(y2-y1, x2-x1)) % 180
|
||||
|
||||
def line_len(x1,y1,x2,y2):
|
||||
return math.hypot(x2-x1, y2-y1)
|
||||
|
||||
def midpoint(x1,y1,x2,y2):
|
||||
return ((x1+x2)/2, (y1+y2)/2)
|
||||
|
||||
def perp_dist(x1,y1,x2,y2,px,py):
|
||||
dx,dy = x2-x1, y2-y1
|
||||
L = math.hypot(dx,dy)
|
||||
if L == 0: return math.hypot(px-x1, py-y1)
|
||||
return abs(dy*px - dx*py + x2*y1 - y2*x1) / L
|
||||
|
||||
# 방향 필터 불필요 — 자갈도상 마스크가 공간 제약 역할을 하므로
|
||||
# 커브 구간도 포함하려면 모든 방향 허용
|
||||
print(f" Hough 검출: {len(lines)}개")
|
||||
|
||||
# 각도별 그룹화 후 거리 클러스터링
|
||||
ATOL = RAIL['cluster_dist']
|
||||
angle_groups = defaultdict(list)
|
||||
for l in lines:
|
||||
key = round(angle(*l) / 5) * 5
|
||||
angle_groups[key].append(l)
|
||||
|
||||
clusters = []
|
||||
for grp in angle_groups.values():
|
||||
used = [False]*len(grp)
|
||||
for i, li in enumerate(grp):
|
||||
if used[i]: continue
|
||||
cl = [li]; used[i] = True
|
||||
mx, my = midpoint(*li)
|
||||
for j, lj in enumerate(grp):
|
||||
if used[j]: continue
|
||||
if perp_dist(*li, *midpoint(*lj)) < ATOL:
|
||||
cl.append(lj); used[j] = True
|
||||
clusters.append(cl)
|
||||
|
||||
def merge(cluster):
|
||||
pts = []
|
||||
for x1,y1,x2,y2 in cluster:
|
||||
pts += [(x1,y1),(x2,y2)]
|
||||
arr = np.array(pts, float)
|
||||
mean = arr.mean(0)
|
||||
cov = np.cov((arr - mean).T)
|
||||
if cov.ndim < 2: direction = np.array([1.0,0.0])
|
||||
else:
|
||||
ev, evec = np.linalg.eig(cov)
|
||||
direction = evec[:, np.argmax(ev)]
|
||||
proj = (arr - mean).dot(direction)
|
||||
idx = np.argsort(proj)
|
||||
sorted_pts = arr[idx]
|
||||
merged = [sorted_pts[0]]
|
||||
for p in sorted_pts[1:]:
|
||||
if math.hypot(p[0]-merged[-1][0], p[1]-merged[-1][1]) > 5:
|
||||
merged.append(p)
|
||||
return merged
|
||||
|
||||
def poly_len(poly):
|
||||
return sum(math.hypot(poly[i+1][0]-poly[i][0], poly[i+1][1]-poly[i][1])
|
||||
for i in range(len(poly)-1))
|
||||
|
||||
result = []
|
||||
for cl in clusters:
|
||||
poly = merge(cl)
|
||||
if poly_len(poly) >= RAIL['min_total_len']:
|
||||
result.append(poly)
|
||||
result.sort(key=lambda p: -poly_len(p))
|
||||
return result
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 침목 검출
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def detect_sleepers(img, rail_polys, gabor_mask=None):
|
||||
"""레일 영역 주변에서 침목(가로 직사각형) 검출."""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
_, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
|
||||
# 레일 위치로 마스크 생성
|
||||
mask = np.zeros(gray.shape, np.uint8)
|
||||
sw = SLEEPER['search_width']
|
||||
for poly in rail_polys:
|
||||
pts = np.array([[int(p[0]), int(p[1])] for p in poly])
|
||||
cv2.polylines(mask, [pts], False, 255, sw * 2)
|
||||
masked = cv2.bitwise_and(thresh, thresh, mask=mask)
|
||||
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
|
||||
cleaned = cv2.morphologyEx(masked, cv2.MORPH_CLOSE, kernel)
|
||||
|
||||
contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
sleepers = []
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
if not (SLEEPER['min_area'] <= area <= SLEEPER['max_area']):
|
||||
continue
|
||||
rect = cv2.minAreaRect(cnt)
|
||||
(cx, cy), (w, h), angle = rect
|
||||
if w == 0 or h == 0: continue
|
||||
aspect = max(w, h) / min(w, h)
|
||||
if aspect < SLEEPER['aspect_min']:
|
||||
continue
|
||||
sleepers.append({
|
||||
'center': (float(cx), float(cy)),
|
||||
'size': (float(w), float(h)),
|
||||
'angle': float(angle),
|
||||
'area': float(area),
|
||||
})
|
||||
return sleepers
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 전철주 검출
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def detect_poles(img, rail_polys):
|
||||
"""레일 주변에서 원형/점형 전철주 검출 (Hough Circle + Blob)."""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
blurred = cv2.GaussianBlur(gray, (9,9), 2)
|
||||
|
||||
# 레일 근처 마스크
|
||||
mask = np.zeros(gray.shape, np.uint8)
|
||||
for poly in rail_polys:
|
||||
pts = np.array([[int(p[0]), int(p[1])] for p in poly])
|
||||
cv2.polylines(mask, [pts], False, 255, POLE['search_margin'] * 2)
|
||||
|
||||
masked = cv2.bitwise_and(blurred, blurred, mask=mask)
|
||||
|
||||
circles = cv2.HoughCircles(
|
||||
masked, cv2.HOUGH_GRADIENT, dp=1,
|
||||
minDist=POLE['min_dist'],
|
||||
param1=POLE['param1'],
|
||||
param2=POLE['param2'],
|
||||
minRadius=POLE['min_radius'],
|
||||
maxRadius=POLE['max_radius'],
|
||||
)
|
||||
|
||||
poles = []
|
||||
if circles is not None:
|
||||
for (cx, cy, r) in circles[0]:
|
||||
poles.append({
|
||||
'center': (float(cx), float(cy)),
|
||||
'radius': float(r),
|
||||
})
|
||||
return poles
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 컨트롤박스 검출
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def detect_control_boxes(img, rail_polys):
|
||||
"""레일 근처 직사각형 객체 검출."""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
blurred = cv2.GaussianBlur(gray, (5,5), 0)
|
||||
edges = cv2.Canny(blurred, 30, 90)
|
||||
|
||||
# 레일 근처 마스크
|
||||
mask = np.zeros(gray.shape, np.uint8)
|
||||
for poly in rail_polys:
|
||||
pts = np.array([[int(p[0]), int(p[1])] for p in poly])
|
||||
cv2.polylines(mask, [pts], False, 255, CBOX['search_margin'] * 2)
|
||||
masked_edges = cv2.bitwise_and(edges, edges, mask=mask)
|
||||
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3))
|
||||
closed = cv2.morphologyEx(masked_edges, cv2.MORPH_CLOSE, kernel, iterations=2)
|
||||
|
||||
contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
boxes = []
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
if not (CBOX['min_area'] <= area <= CBOX['max_area']):
|
||||
continue
|
||||
hull = cv2.convexHull(cnt)
|
||||
hull_area = cv2.contourArea(hull)
|
||||
if hull_area == 0: continue
|
||||
solidity = area / hull_area
|
||||
if solidity < CBOX['min_solidity']:
|
||||
continue
|
||||
rect = cv2.minAreaRect(cnt)
|
||||
(cx, cy), (w, h), angle = rect
|
||||
if w == 0 or h == 0: continue
|
||||
aspect = max(w, h) / min(w, h)
|
||||
if not (CBOX['aspect_min'] <= aspect <= CBOX['aspect_max']):
|
||||
continue
|
||||
boxes.append({
|
||||
'center': (float(cx), float(cy)),
|
||||
'size': (float(w), float(h)),
|
||||
'angle': float(angle),
|
||||
'area': float(area),
|
||||
})
|
||||
return boxes
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# DXF 출력
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def save_dxf(output_path, rails_world, sleepers_world, poles_world, boxes_world):
|
||||
import ezdxf
|
||||
doc = ezdxf.new("R2010")
|
||||
msp = doc.modelspace()
|
||||
|
||||
# 레이어 정의
|
||||
doc.layers.add("RAIL", color=1) # 빨강
|
||||
doc.layers.add("SLEEPER", color=3) # 초록
|
||||
doc.layers.add("POLE", color=5) # 파랑
|
||||
doc.layers.add("CONTROL_BOX", color=6) # 마젠타
|
||||
|
||||
# 레일 중심선
|
||||
for poly in rails_world:
|
||||
if len(poly) >= 2:
|
||||
msp.add_lwpolyline(poly, dxfattribs={"layer": "RAIL"})
|
||||
|
||||
# 침목
|
||||
for s in sleepers_world:
|
||||
cx, cy = s['world_center']
|
||||
w, h = s['size']
|
||||
ang = s['angle']
|
||||
scale = s.get('scale', 1.0)
|
||||
# 회전된 직사각형
|
||||
box_pts = cv2.boxPoints(((cx, -cy), (w/scale, h/scale), ang))
|
||||
msp.add_lwpolyline(
|
||||
[(float(p[0]), float(p[1])) for p in box_pts] + [(float(box_pts[0][0]), float(box_pts[0][1]))],
|
||||
dxfattribs={"layer": "SLEEPER"}
|
||||
)
|
||||
|
||||
# 전철주 (원)
|
||||
for p in poles_world:
|
||||
cx, cy = p['world_center']
|
||||
r = p.get('world_radius', 1.0)
|
||||
msp.add_circle((cx, -cy), r, dxfattribs={"layer": "POLE"})
|
||||
|
||||
# 컨트롤박스
|
||||
for b in boxes_world:
|
||||
cx, cy = b['world_center']
|
||||
w, h = b['size']
|
||||
ang = b['angle']
|
||||
scale = b.get('scale', 1.0)
|
||||
box_pts = cv2.boxPoints(((cx, -cy), (w/scale, h/scale), ang))
|
||||
msp.add_lwpolyline(
|
||||
[(float(p[0]), float(p[1])) for p in box_pts] + [(float(box_pts[0][0]), float(box_pts[0][1]))],
|
||||
dxfattribs={"layer": "CONTROL_BOX"}
|
||||
)
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.saveas(output_path)
|
||||
print(f" DXF: {output_path}")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# GeoJSON 출력
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def save_geojson(output_path, rails_world, sleepers_world, poles_world, boxes_world):
|
||||
features = []
|
||||
|
||||
for i, poly in enumerate(rails_world):
|
||||
features.append({
|
||||
"type": "Feature",
|
||||
"geometry": {"type": "LineString", "coordinates": [[x, y] for x, y in poly]},
|
||||
"properties": {"type": "rail", "id": i}
|
||||
})
|
||||
for i, s in enumerate(sleepers_world):
|
||||
cx, cy = s['world_center']
|
||||
features.append({
|
||||
"type": "Feature",
|
||||
"geometry": {"type": "Point", "coordinates": [cx, cy]},
|
||||
"properties": {"type": "sleeper", "id": i, "angle": s['angle']}
|
||||
})
|
||||
for i, p in enumerate(poles_world):
|
||||
cx, cy = p['world_center']
|
||||
features.append({
|
||||
"type": "Feature",
|
||||
"geometry": {"type": "Point", "coordinates": [cx, cy]},
|
||||
"properties": {"type": "pole", "id": i, "radius_m": p.get('world_radius', 0)}
|
||||
})
|
||||
for i, b in enumerate(boxes_world):
|
||||
cx, cy = b['world_center']
|
||||
features.append({
|
||||
"type": "Feature",
|
||||
"geometry": {"type": "Point", "coordinates": [cx, cy]},
|
||||
"properties": {"type": "control_box", "id": i,
|
||||
"width_m": b.get('world_w', 0), "height_m": b.get('world_h', 0)}
|
||||
})
|
||||
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
if not output_path.endswith('.geojson'):
|
||||
output_path += '.geojson'
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(geojson, f, ensure_ascii=False, indent=2)
|
||||
print(f" GeoJSON: {output_path}")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 시각화
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def save_debug(img, rails, sleepers, poles, boxes, output_path, ballast_mask=None):
|
||||
vis = img.copy()
|
||||
# 자갈도상 마스크 반투명 오버레이 (노란색)
|
||||
if ballast_mask is not None:
|
||||
overlay = vis.copy()
|
||||
overlay[ballast_mask > 0] = (0, 200, 200)
|
||||
cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis)
|
||||
# 레일
|
||||
for poly in rails:
|
||||
pts = np.array([[int(p[0]), int(p[1])] for p in poly])
|
||||
cv2.polylines(vis, [pts], False, (0,0,255), 2)
|
||||
# 침목
|
||||
for s in sleepers:
|
||||
cx, cy = int(s['center'][0]), int(s['center'][1])
|
||||
rect = ((cx,cy), (s['size'][0], s['size'][1]), s['angle'])
|
||||
box = cv2.boxPoints(rect).astype(int)
|
||||
cv2.drawContours(vis, [box], 0, (0,255,0), 1)
|
||||
# 전철주
|
||||
for p in poles:
|
||||
cx, cy, r = int(p['center'][0]), int(p['center'][1]), int(p['radius'])
|
||||
cv2.circle(vis, (cx, cy), r, (255,0,0), 2)
|
||||
# 컨트롤박스
|
||||
for b in boxes:
|
||||
cx, cy = int(b['center'][0]), int(b['center'][1])
|
||||
rect = ((cx,cy), (b['size'][0], b['size'][1]), b['angle'])
|
||||
box = cv2.boxPoints(rect).astype(int)
|
||||
cv2.drawContours(vis, [box], 0, (255,0,255), 2)
|
||||
|
||||
# 범례
|
||||
legend = [("RAIL", (0,0,255)),
|
||||
("SLEEPER", (0,255,0)),
|
||||
("POLE", (255,0,0)),
|
||||
("CONTROL_BOX", (255,0,255))]
|
||||
for i, (name, color) in enumerate(legend):
|
||||
cv2.rectangle(vis, (10, 10+i*28), (24, 24+i*28), color, -1)
|
||||
cv2.putText(vis, name, (30, 22+i*28),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
||||
|
||||
with open(str(output_path), 'wb') as f:
|
||||
_, buf = cv2.imencode('.jpg', vis, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
f.write(buf.tobytes())
|
||||
print(f" 시각화: {output_path}")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# 메인 파이프라인
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def process(image_path: str, output_dir: str):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"입력: {image_path}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 좌표 변환기 초기화
|
||||
print("\n[좌표계]")
|
||||
geo = GeoTransform(image_path)
|
||||
|
||||
# 이미지 로드
|
||||
print("\n[이미지 로드]")
|
||||
img = load_image(image_path)
|
||||
if img is None:
|
||||
print("오류: 이미지를 열 수 없습니다.")
|
||||
sys.exit(1)
|
||||
H, W = img.shape[:2]
|
||||
print(f" 크기: {W} x {H} px")
|
||||
|
||||
# 처리용 축소
|
||||
scale = 1.0
|
||||
if W > SCALE_FACTOR:
|
||||
scale = SCALE_FACTOR / W
|
||||
small = cv2.resize(img, (int(W*scale), int(H*scale)))
|
||||
print(f" 축소: {int(W*scale)} x {int(H*scale)} (x{scale:.2f})")
|
||||
else:
|
||||
small = img.copy()
|
||||
|
||||
out = Path(output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
stem = Path(image_path).stem
|
||||
|
||||
# ── Gabor 침목 패턴 검출 + 레일 중심선 ───────────────────
|
||||
print("\n[1] Gabor 침목 패턴 → 레일 검출...")
|
||||
rails_px, gabor_mask = detect_rails_gabor(small)
|
||||
print(f" 검출: {len(rails_px)}개")
|
||||
|
||||
# Gabor 마스크 저장 (확인용)
|
||||
mask_path = out / f"{stem}_gabor_mask.jpg"
|
||||
with open(str(mask_path), 'wb') as _f:
|
||||
_, _buf = cv2.imencode('.jpg', gabor_mask)
|
||||
_f.write(_buf.tobytes())
|
||||
print(f" Gabor 마스크: {mask_path}")
|
||||
|
||||
# 레일 실좌표 변환
|
||||
rails_world = []
|
||||
for poly in rails_px:
|
||||
world_pts = geo.scale_coords(poly, scale)
|
||||
rails_world.append(world_pts)
|
||||
|
||||
# ── 침목 검출 ──────────────────────────────────────────
|
||||
print("\n[2] 침목 검출...")
|
||||
sleepers_px = detect_sleepers(small, rails_px, gabor_mask)
|
||||
print(f" 검출: {len(sleepers_px)}개")
|
||||
|
||||
sleepers_world = []
|
||||
for s in sleepers_px:
|
||||
cx, cy = s['center']
|
||||
wx, wy = geo.scale_point(cx, cy, scale)
|
||||
ps = geo.pixel_size[0] / scale
|
||||
sleepers_world.append({
|
||||
**s,
|
||||
'world_center': (wx, wy),
|
||||
'scale': scale,
|
||||
})
|
||||
|
||||
# ── 전철주 검출 ────────────────────────────────────────
|
||||
print("\n[3] 전철주 검출...")
|
||||
poles_px = detect_poles(small, rails_px)
|
||||
print(f" 검출: {len(poles_px)}개")
|
||||
|
||||
poles_world = []
|
||||
for p in poles_px:
|
||||
cx, cy = p['center']
|
||||
wx, wy = geo.scale_point(cx, cy, scale)
|
||||
r_m = p['radius'] * geo.pixel_size[0] / scale
|
||||
poles_world.append({
|
||||
**p,
|
||||
'world_center': (wx, wy),
|
||||
'world_radius': r_m,
|
||||
})
|
||||
|
||||
# ── 컨트롤박스 검출 ────────────────────────────────────
|
||||
print("\n[4] 컨트롤박스 검출...")
|
||||
boxes_px = detect_control_boxes(small, rails_px)
|
||||
print(f" 검출: {len(boxes_px)}개")
|
||||
|
||||
boxes_world = []
|
||||
ps = geo.pixel_size[0]
|
||||
for b in boxes_px:
|
||||
cx, cy = b['center']
|
||||
wx, wy = geo.scale_point(cx, cy, scale)
|
||||
w_m = b['size'][0] * ps / scale
|
||||
h_m = b['size'][1] * ps / scale
|
||||
boxes_world.append({
|
||||
**b,
|
||||
'world_center': (wx, wy),
|
||||
'world_w': w_m,
|
||||
'world_h': h_m,
|
||||
'scale': scale,
|
||||
})
|
||||
|
||||
# ── 결과 요약 ──────────────────────────────────────────
|
||||
print(f"\n{'='*60}")
|
||||
print(f"검출 결과 요약")
|
||||
print(f" 레일: {len(rails_world):>5}개")
|
||||
print(f" 침목: {len(sleepers_world):>5}개")
|
||||
print(f" 전철주: {len(poles_world):>5}개")
|
||||
print(f" 컨트롤박스: {len(boxes_world):>5}개")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# ── 출력 ──────────────────────────────────────────────
|
||||
print("\n[출력]")
|
||||
save_dxf(
|
||||
str(out / f"{stem}_railway.dxf"),
|
||||
rails_world, sleepers_world, poles_world, boxes_world
|
||||
)
|
||||
save_geojson(
|
||||
str(out / f"{stem}_railway.geojson"),
|
||||
rails_world, sleepers_world, poles_world, boxes_world
|
||||
)
|
||||
save_debug(
|
||||
small, rails_px, sleepers_px, poles_px, boxes_px,
|
||||
out / f"{stem}_railway_debug.jpg",
|
||||
ballast_mask=gabor_mask
|
||||
)
|
||||
|
||||
print(f"\n완료.")
|
||||
print(f" DXF 레이어: RAIL(빨강) / SLEEPER(초록) / POLE(파랑) / CONTROL_BOX(마젠타)")
|
||||
print(f" Rhino: Import DXF → 레이어별 3D 모델 교체")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("사용법: python tools/railway_pipeline.py <image.tif|png> [output_dir]")
|
||||
sys.exit(1)
|
||||
|
||||
img_path = sys.argv[1]
|
||||
output_dir = sys.argv[2] if len(sys.argv) >= 3 else "output"
|
||||
process(img_path, output_dir)
|
||||
535
tools/render_skeleton_overlay.py
Normal file
535
tools/render_skeleton_overlay.py
Normal file
@@ -0,0 +1,535 @@
|
||||
"""SAM 원본 폴리곤의 skeleton을 까만색으로 이미지에 겹쳐 표시.
|
||||
union mask 기준 skeleton을 그리므로, 겹쳐있는 fragment들의 합쳐진 골격이 보임.
|
||||
|
||||
사용:
|
||||
python tools/render_skeleton_overlay.py \
|
||||
--image data/역사이미지/slope/DJI_20260306113842_0006.JPG \
|
||||
--label output/autolabel/raw_0006/labels/DJI_20260306113842_0006.txt \
|
||||
--output output/autolabel/raw_0006/vis_skeleton.jpg
|
||||
"""
|
||||
import argparse
|
||||
import base64
|
||||
import requests
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
SAM3_MODEL_ID = "segment_anything_3"
|
||||
|
||||
SUB_CATEGORIES = [
|
||||
{"name": "pole_shaft", "name_kr": "전주기둥",
|
||||
"prompt": "vertical steel pole shaft, cylindrical mast, tubular pole body",
|
||||
"keywords": ["pole", "shaft", "mast", "tubular"],
|
||||
"color_bgr": [0, 200, 255]},
|
||||
{"name": "bracket_arm", "name_kr": "브라켓/암",
|
||||
"prompt": "horizontal cantilever arm, bracket arm, cross arm, pole cross arm",
|
||||
"keywords": ["bracket", "cantilever", "cross arm", "arm"],
|
||||
"color_bgr": [255, 100, 0]},
|
||||
{"name": "insulator", "name_kr": "애자",
|
||||
"prompt": "ceramic insulator, glass insulator, suspension insulator, string of insulators",
|
||||
"keywords": ["insulator"],
|
||||
"color_bgr": [0, 255, 255]},
|
||||
{"name": "wire", "name_kr": "가선/조가선",
|
||||
"prompt": "overhead contact wire, catenary wire, messenger wire, electric cable",
|
||||
"keywords": ["wire", "cable", "catenary"],
|
||||
"color_bgr": [200, 200, 255]},
|
||||
]
|
||||
|
||||
|
||||
def _sam3_call(crop_bgr: np.ndarray, prompt: str, conf: float = 0.15) -> list:
|
||||
h, w = crop_bgr.shape[:2]
|
||||
scale = 1.0
|
||||
if max(h, w) > 1280:
|
||||
scale = 1280 / max(h, w)
|
||||
crop_bgr = cv2.resize(crop_bgr, (int(w * scale), int(h * scale)))
|
||||
_, buf = cv2.imencode(".jpg", crop_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
||||
b64 = base64.b64encode(buf).decode()
|
||||
payload = {
|
||||
"model": SAM3_MODEL_ID,
|
||||
"image": b64,
|
||||
"params": {"text_prompt": prompt, "conf_threshold": conf,
|
||||
"show_masks": True, "show_boxes": False},
|
||||
}
|
||||
try:
|
||||
r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
if not resp.get("success"):
|
||||
return []
|
||||
shapes = resp.get("data", {}).get("shapes", [])
|
||||
shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
|
||||
if scale < 1.0:
|
||||
inv = 1.0 / scale
|
||||
for s in shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
|
||||
return [s for s in shapes if s.get("shape_type") == "polygon"]
|
||||
except Exception as e:
|
||||
print(f" SAM3 오류: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _label_to_subcat(label_str: str):
|
||||
ll = label_str.lower()
|
||||
for cat in SUB_CATEGORIES:
|
||||
if any(kw in ll for kw in cat["keywords"]):
|
||||
return cat
|
||||
return None
|
||||
|
||||
|
||||
def _subdetect_group(img: np.ndarray, group_polys: list, pad: int = 80) -> int:
|
||||
"""그룹 폴리곤들의 union bbox crop → SAM3 sub-detect → img 오버레이."""
|
||||
all_pts = np.vstack(group_polys)
|
||||
H, W = img.shape[:2]
|
||||
x0 = max(0, int(all_pts[:, 0].min()) - pad)
|
||||
y0 = max(0, int(all_pts[:, 1].min()) - pad)
|
||||
x1 = min(W, int(all_pts[:, 0].max()) + pad)
|
||||
y1 = min(H, int(all_pts[:, 1].max()) + pad)
|
||||
|
||||
crop = img[y0:y1, x0:x1].copy()
|
||||
combined_prompt = ", ".join(c["prompt"] for c in SUB_CATEGORIES)
|
||||
shapes = _sam3_call(crop, combined_prompt)
|
||||
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
for s in shapes:
|
||||
cat = _label_to_subcat(s.get("label", ""))
|
||||
color = tuple(cat["color_bgr"]) if cat else (128, 128, 128)
|
||||
pts = np.array([[int(px + x0), int(py + y0)] for px, py in s["points"]], dtype=np.int32)
|
||||
overlay = img.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.30, img, 0.70, 0, img)
|
||||
cv2.polylines(img, [pts], True, color, 2)
|
||||
cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
|
||||
name = cat["name"] if cat else "?"
|
||||
cv2.putText(img, name, (cx, cy), font, 0.7, color, 2, cv2.LINE_AA)
|
||||
|
||||
cv2.rectangle(img, (x0, y0), (x1, y1), (255, 200, 0), 2)
|
||||
return len(shapes)
|
||||
|
||||
|
||||
def render(image_path: Path, label_path: Path, output_path: Path,
|
||||
selected_indices=None, branch_radius: int = 10, class_ids=None,
|
||||
subdetect: bool = False, subdetect_polys: set = None):
|
||||
buf = np.fromfile(str(image_path), dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
img_orig = img.copy()
|
||||
H, W = img.shape[:2]
|
||||
|
||||
text = label_path.read_text(encoding="utf-8").strip() if label_path.exists() else ""
|
||||
if class_ids is not None:
|
||||
text = "\n".join(
|
||||
ln for ln in text.splitlines()
|
||||
if ln.split() and int(ln.split()[0]) in class_ids
|
||||
)
|
||||
all_polygons = []
|
||||
for line in text.splitlines():
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
continue
|
||||
coords = list(map(float, parts[1:]))
|
||||
pts = np.array(
|
||||
[[coords[i] * W, coords[i + 1] * H] for i in range(0, len(coords), 2)],
|
||||
dtype=np.int32,
|
||||
)
|
||||
if len(pts) >= 3:
|
||||
all_polygons.append(pts)
|
||||
|
||||
if selected_indices is not None:
|
||||
keep_idx = [i for i in sorted(selected_indices) if 0 <= i < len(all_polygons)]
|
||||
else:
|
||||
keep_idx = list(range(len(all_polygons)))
|
||||
polygons = [all_polygons[i] for i in keep_idx]
|
||||
print(f" total polygons in label: {len(all_polygons)}, "
|
||||
f"selected: {len(polygons)} (indices: {keep_idx})")
|
||||
|
||||
# 수직/수평 판별: 장축 vs 이미지 중심→폴리곤 중심 radial 방향 비교
|
||||
# 드론 사선 촬영에서 수직 기둥은 이미지 중심에서 방사형으로 기울어져 보임
|
||||
img_cx, img_cy = W / 2.0, H / 2.0
|
||||
|
||||
def _poly_orient(pts):
|
||||
rect = cv2.minAreaRect(pts)
|
||||
(rx, ry), (rw, rh), angle = rect
|
||||
if min(rw, rh) < 1:
|
||||
return (160, 160, 160), '?', 0.0
|
||||
ar = max(rw, rh) / min(rw, rh)
|
||||
if ar < 1.3:
|
||||
return (160, 160, 160), '?', 0.0
|
||||
# 장축 방향 벡터
|
||||
long_angle_deg = angle if rw >= rh else angle + 90
|
||||
lx = np.cos(np.radians(long_angle_deg))
|
||||
ly = np.sin(np.radians(long_angle_deg))
|
||||
# radial 방향: 이미지 중심 → 폴리곤 중심
|
||||
rdx, rdy = rx - img_cx, ry - img_cy
|
||||
radial_norm = (rdx**2 + rdy**2) ** 0.5
|
||||
if radial_norm < 1:
|
||||
return (160, 160, 160), '?', 0.0
|
||||
rdx, rdy = rdx / radial_norm, rdy / radial_norm
|
||||
# 장축과 radial 방향의 정렬도 (|cos θ|)
|
||||
cos_sim = abs(lx * rdx + ly * rdy)
|
||||
# cos_sim → 1: 장축이 radial 방향 = 수직 기둥
|
||||
# cos_sim → 0: 장축이 radial ⊥ 방향 = 수평 빔
|
||||
if cos_sim > 0.7:
|
||||
return (255, 80, 0), 'V', cos_sim
|
||||
else:
|
||||
return (0, 80, 255), 'H', cos_sim
|
||||
|
||||
orient_map = {} # orig_idx → 'V'/'H'/'?'
|
||||
print("\n [폴리곤 수직/수평 분류]")
|
||||
for orig_idx, pts in zip(keep_idx, polygons):
|
||||
color, orient, cos_sim = _poly_orient(pts)
|
||||
orient_map[orig_idx] = orient
|
||||
print(f" poly {orig_idx:>2d}: {orient} cos_sim={cos_sim:.3f}")
|
||||
overlay = img.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.25, img, 0.75, 0, img)
|
||||
cv2.polylines(img, [pts], True, color, 2)
|
||||
cx = int(pts[:, 0].mean())
|
||||
cy = int(pts[:, 1].mean())
|
||||
label = f"{orig_idx}{orient}"
|
||||
cv2.putText(img, label, (cx, cy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 5)
|
||||
cv2.putText(img, label, (cx, cy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2)
|
||||
|
||||
# 모든 폴리곤 합친 마스크 → skeleton
|
||||
merged = np.zeros((H, W), dtype=np.uint8)
|
||||
for pts in polygons:
|
||||
cv2.fillPoly(merged, [pts], 255)
|
||||
# 미세한 픽셀 overlap으로 다른 구조가 합쳐지는 것 방지: 작은 erosion
|
||||
# 5px dilation으로 인접 폴리곤 gap 브리지 후 3px erosion으로 복원
|
||||
merged_dilated = cv2.dilate(merged, np.ones((11, 11), dtype=np.uint8))
|
||||
merged_eroded = cv2.erode(merged_dilated, np.ones((7, 7), dtype=np.uint8))
|
||||
skel = cv2.ximgproc.thinning(merged_eroded)
|
||||
|
||||
# skeleton 두껍게 (시인성)
|
||||
skel_thick = cv2.dilate(skel, np.ones((5, 5), dtype=np.uint8))
|
||||
img[skel_thick > 0] = (0, 0, 0)
|
||||
|
||||
# spur pruning: 짧은 dead-end arm 제거 → 진짜 branch만 남음
|
||||
_box3 = np.ones((3, 3), dtype=np.float32)
|
||||
|
||||
def _prune_spurs(sk, min_arm_px):
|
||||
"""branch pixel 제거 후 arm component 분석, 짧은 arm 제거."""
|
||||
s_i = (sk > 0).astype(np.int32)
|
||||
sp2 = np.pad(s_i, 1, mode='constant')
|
||||
h2, w2 = s_i.shape
|
||||
n2 = [sp2[0:h2, 0:w2], sp2[0:h2, 1:w2+1], sp2[0:h2, 2:w2+2],
|
||||
sp2[1:h2+1, 2:w2+2], sp2[2:h2+2, 2:w2+2], sp2[2:h2+2, 1:w2+1],
|
||||
sp2[2:h2+2, 0:w2], sp2[1:h2+1, 0:w2]]
|
||||
cn2 = np.zeros((h2, w2), dtype=np.int32)
|
||||
for i2 in range(8):
|
||||
cn2 += (1 - n2[i2]) * n2[(i2 + 1) % 8]
|
||||
cn2 *= s_i
|
||||
arm_sk = sk.copy()
|
||||
arm_sk[cn2 >= 3] = 0 # branch pixel 제거
|
||||
n_arm, arm_lbl, arm_stats, arm_cents = cv2.connectedComponentsWithStats(arm_sk)
|
||||
sk_out = sk.copy()
|
||||
spur_cs = []
|
||||
for a_id in range(1, n_arm):
|
||||
if arm_stats[a_id, cv2.CC_STAT_AREA] < min_arm_px:
|
||||
sk_out[arm_lbl == a_id] = 0
|
||||
spur_cs.append((int(arm_cents[a_id, 0]), int(arm_cents[a_id, 1])))
|
||||
return sk_out, spur_cs
|
||||
|
||||
skel_pruned, spur_centroids = _prune_spurs(skel, min_arm_px=30)
|
||||
print(f" spur pruning: {len(spur_centroids)}개 arm 제거 (min_arm_px=30)")
|
||||
|
||||
skel_pf = (skel_pruned > 0).astype(np.float32)
|
||||
nbr_p = cv2.filter2D(skel_pf, cv2.CV_32F, _box3) - skel_pf
|
||||
deg_p = (nbr_p * skel_pf).astype(np.int32)
|
||||
|
||||
branch_mask_global = ((deg_p >= 3) & (skel_pruned > 0)).astype(np.uint8) * 255
|
||||
endpoint_mask_global = ((deg_p == 1) & (skel_pruned > 0)).astype(np.uint8) * 255
|
||||
|
||||
def cluster_centroids(mask, cluster_radius):
|
||||
if not mask.any():
|
||||
return []
|
||||
dilated = cv2.dilate(mask, np.ones((cluster_radius, cluster_radius), np.uint8))
|
||||
num, _, _, cents = cv2.connectedComponentsWithStats(dilated)
|
||||
return [(int(cents[i, 0]), int(cents[i, 1])) for i in range(1, num)]
|
||||
|
||||
def extreme_endpoint_pair(eps):
|
||||
if len(eps) < 2:
|
||||
return None
|
||||
best, best_d2 = None, -1
|
||||
for i in range(len(eps)):
|
||||
for j in range(i + 1, len(eps)):
|
||||
dx = eps[i][0] - eps[j][0]
|
||||
dy = eps[i][1] - eps[j][1]
|
||||
d2 = dx * dx + dy * dy
|
||||
if d2 > best_d2:
|
||||
best_d2 = d2
|
||||
best = (eps[i], eps[j])
|
||||
return best
|
||||
|
||||
# 그룹핑 = pruned skeleton의 connected component
|
||||
num_comp, comp_labels = cv2.connectedComponents(skel_pruned)
|
||||
|
||||
# 진단: 각 polygon이 어느 component에 속하는지 + poly_idx→cids 역매핑 빌드
|
||||
print(f"\n [Diagnostic] skeleton components: {num_comp - 1}")
|
||||
poly_to_cids: dict = {}
|
||||
for i, pts in enumerate(polygons):
|
||||
poly_mask = np.zeros((H, W), dtype=np.uint8)
|
||||
cv2.fillPoly(poly_mask, [pts], 255)
|
||||
in_skel = (skel_pruned > 0) & (poly_mask > 0)
|
||||
if in_skel.any():
|
||||
cids = sorted(set(int(c) for c in np.unique(comp_labels[in_skel]) if c > 0))
|
||||
poly_to_cids[i] = cids
|
||||
else:
|
||||
print(f" poly {i:>2d}: (skeleton 없음 — 너무 작거나 erosion됨)")
|
||||
|
||||
# 역매핑: component → polygon 목록
|
||||
cid_to_polys: dict = {}
|
||||
for poly_i, cids in poly_to_cids.items():
|
||||
for cid in cids:
|
||||
cid_to_polys.setdefault(cid, []).append(poly_i)
|
||||
|
||||
print(f"\n [그룹 → 폴리곤 목록] (전체 {len(cid_to_polys)}그룹)")
|
||||
for cid in sorted(cid_to_polys.keys()):
|
||||
n_px = int((comp_labels == cid).sum())
|
||||
polys = cid_to_polys[cid]
|
||||
orients = [f"{pi}({orient_map.get(pi, '?')})" for pi in polys]
|
||||
has_v = any(orient_map.get(pi) == 'V' for pi in polys)
|
||||
has_h = any(orient_map.get(pi) == 'H' for pi in polys)
|
||||
tag = " ★라멘?" if (has_v and has_h) else ""
|
||||
print(f" Comp {cid:>2d} ({n_px:>5d}px): {orients}{tag}")
|
||||
print()
|
||||
|
||||
# subdetect_polys가 지정된 경우, 해당 poly들이 속한 component만 sub-detect
|
||||
do_subdetect = subdetect or bool(subdetect_polys)
|
||||
target_cids: set = set()
|
||||
if subdetect_polys:
|
||||
for pi in subdetect_polys:
|
||||
for cid_ in poly_to_cids.get(pi, []):
|
||||
target_cids.add(cid_)
|
||||
print(f" sub-detect 대상 그룹: {sorted(target_cids)} (poly {sorted(subdetect_polys)} 기준)\n")
|
||||
|
||||
# cid → 해당 component에 속하는 polygon pts 목록 (subdetect용)
|
||||
comp_to_polys: dict = {}
|
||||
if do_subdetect:
|
||||
for pts in polygons:
|
||||
pmask = np.zeros((H, W), dtype=np.uint8)
|
||||
cv2.fillPoly(pmask, [pts], 255)
|
||||
in_s = (skel_pruned > 0) & (pmask > 0)
|
||||
if in_s.any():
|
||||
for cid_ in set(int(c) for c in comp_labels[in_s] if c > 0):
|
||||
comp_to_polys.setdefault(cid_, []).append(pts)
|
||||
|
||||
total_branches = 0
|
||||
total_endpoints = 0
|
||||
total_groups = 0
|
||||
total_lines = 0
|
||||
|
||||
for cid in range(1, num_comp):
|
||||
comp_mask = (comp_labels == cid)
|
||||
if int(comp_mask.sum()) < 50:
|
||||
continue
|
||||
total_groups += 1
|
||||
|
||||
comp_skel = (skel_pruned * comp_mask.astype(np.uint8)).astype(np.uint8)
|
||||
comp_branch = cv2.bitwise_and(branch_mask_global, comp_skel)
|
||||
comp_endpoint = cv2.bitwise_and(endpoint_mask_global, comp_skel)
|
||||
|
||||
branches = cluster_centroids(comp_branch, 7)
|
||||
endpoints = cluster_centroids(comp_endpoint, 5)
|
||||
|
||||
# Virtual endpoints: endpoint 3개인 component에만 적용 (라멘에서 1개 누락된 경우).
|
||||
# 알고리즘:
|
||||
# 1) skeleton bbox 계산 → 4 변 (top/bottom/left/right)
|
||||
# 2) 각 endpoint를 가장 가까운 변에 배정
|
||||
# 3) 4 변 중 endpoint 없는 "빠진 방향" 찾음
|
||||
# 4) 빠진 방향에 따라 branch 중 극값(가장 X/Y가 작거나 큰)을 virtual endpoint로
|
||||
virtual_endpoints = []
|
||||
if len(endpoints) == 3 and branches:
|
||||
ys_arr, xs_arr = np.where(comp_skel > 0)
|
||||
if len(ys_arr) > 0:
|
||||
xmin, xmax = int(xs_arr.min()), int(xs_arr.max())
|
||||
ymin, ymax = int(ys_arr.min()), int(ys_arr.max())
|
||||
# 각 endpoint를 가장 가까운 변에 배정
|
||||
occupied = set()
|
||||
for ex, ey in endpoints:
|
||||
d_top = ey - ymin
|
||||
d_bot = ymax - ey
|
||||
d_left = ex - xmin
|
||||
d_right = xmax - ex
|
||||
side, _ = min(
|
||||
(('top', d_top), ('bottom', d_bot),
|
||||
('left', d_left), ('right', d_right)),
|
||||
key=lambda x: x[1])
|
||||
occupied.add(side)
|
||||
# 빠진 방향들
|
||||
all_sides = {'top', 'bottom', 'left', 'right'}
|
||||
missing_sides = all_sides - occupied
|
||||
# 각 빠진 방향마다 극값 branch 찾기
|
||||
for ms in missing_sides:
|
||||
if ms == 'right':
|
||||
target = max(branches, key=lambda b: b[0])
|
||||
elif ms == 'left':
|
||||
target = min(branches, key=lambda b: b[0])
|
||||
elif ms == 'top':
|
||||
target = min(branches, key=lambda b: b[1])
|
||||
elif ms == 'bottom':
|
||||
target = max(branches, key=lambda b: b[1])
|
||||
too_close = any((target[0] - px) ** 2 + (target[1] - py) ** 2 < 80 ** 2
|
||||
for px, py in endpoints + virtual_endpoints)
|
||||
if not too_close:
|
||||
virtual_endpoints.append(target)
|
||||
print(f" virtual endpoint: missing={list(missing_sides)}, "
|
||||
f"added={virtual_endpoints}")
|
||||
|
||||
# 통합 endpoint 리스트 (real + virtual)
|
||||
endpoints_all = endpoints + virtual_endpoints
|
||||
total_branches += len(branches)
|
||||
total_endpoints += len(endpoints_all)
|
||||
|
||||
# ── 새 알고리즘: 최상단 노드 기반 topology ──────────────────────
|
||||
if len(endpoints_all) < 2:
|
||||
pass
|
||||
else:
|
||||
def node_dist(a, b):
|
||||
return ((a[0]-b[0])**2 + (a[1]-b[1])**2) ** 0.5
|
||||
|
||||
# 1. 최상단 노드 T (y 최솟값)
|
||||
top_i = min(range(len(endpoints_all)),
|
||||
key=lambda i: (endpoints_all[i][1], endpoints_all[i][0]))
|
||||
T = endpoints_all[top_i]
|
||||
other_idx = [i for i in range(len(endpoints_all)) if i != top_i]
|
||||
|
||||
connected_idx = {top_i}
|
||||
lines_to_draw = []
|
||||
|
||||
if len(other_idx) == 1:
|
||||
j = other_idx[0]
|
||||
lines_to_draw.append((T, endpoints_all[j]))
|
||||
connected_idx.add(j)
|
||||
longer_i = j
|
||||
else:
|
||||
# 2. Left(min x), Right(max x) → T에서 이 둘에만 연결
|
||||
left_i = min(other_idx, key=lambda i: endpoints_all[i][0])
|
||||
right_i = max(other_idx, key=lambda i: endpoints_all[i][0])
|
||||
lines_to_draw.append((T, endpoints_all[left_i]))
|
||||
lines_to_draw.append((T, endpoints_all[right_i]))
|
||||
connected_idx.update([left_i, right_i])
|
||||
|
||||
# 3. 긴 쪽 결정
|
||||
d_l = node_dist(T, endpoints_all[left_i])
|
||||
d_r = node_dist(T, endpoints_all[right_i])
|
||||
longer_i = left_i if d_l >= d_r else right_i
|
||||
|
||||
# 4. 체이닝: 긴 쪽 끝점에서 y+ 방향 미연결 노드로 계속 연결
|
||||
current_i = longer_i
|
||||
while True:
|
||||
cur = endpoints_all[current_i]
|
||||
below = [(i, node_dist(cur, endpoints_all[i]))
|
||||
for i in range(len(endpoints_all))
|
||||
if i not in connected_idx and endpoints_all[i][1] > cur[1]]
|
||||
if not below:
|
||||
break
|
||||
next_i = min(below, key=lambda x: x[1])[0]
|
||||
lines_to_draw.append((cur, endpoints_all[next_i]))
|
||||
connected_idx.add(next_i)
|
||||
current_i = next_i
|
||||
|
||||
for p1, p2 in lines_to_draw:
|
||||
cv2.line(img, p1, p2, (0, 255, 255), 5)
|
||||
total_lines += 1
|
||||
print(f" line {p1}→{p2}: {node_dist(p1,p2):.0f}px")
|
||||
|
||||
# 노드 표시 (line 위에 올라오게)
|
||||
for cx, cy in branches:
|
||||
cv2.circle(img, (cx, cy), 14, (0, 220, 0), 3) # 초록 중간 원 (진짜 T/X junction)
|
||||
for cx, cy in endpoints:
|
||||
cv2.circle(img, (cx, cy), 18, (0, 255, 0), 4)
|
||||
# virtual endpoint = 오렌지 원
|
||||
for cx, cy in virtual_endpoints:
|
||||
cv2.circle(img, (cx, cy), 18, (0, 140, 255), 4)
|
||||
|
||||
should_sub = do_subdetect and cid in comp_to_polys
|
||||
if should_sub and target_cids:
|
||||
should_sub = cid in target_cids
|
||||
if should_sub:
|
||||
n = _subdetect_group(img, comp_to_polys[cid])
|
||||
print(f" 그룹 {cid}: sub-detect {n}개")
|
||||
|
||||
# 제거된 spur arm centroids = 분홍 작은 원
|
||||
for cx, cy in spur_centroids:
|
||||
cv2.circle(img, (cx, cy), 6, (180, 120, 180), 2)
|
||||
|
||||
print(f" {len(polygons)} polygons, skeleton 픽셀 {int((skel > 0).sum())}")
|
||||
print(f" 그룹(skeleton component) {total_groups}개")
|
||||
print(f" 분기점 {total_branches}개, 끝점 {total_endpoints}개, 연결선 {total_lines}개")
|
||||
|
||||
# 그룹 시각화: comp별 색상으로 폴리곤 표시 (원본 이미지 위)
|
||||
_PALETTE = [
|
||||
(255, 80, 80), ( 80, 220, 80), ( 80, 120, 255), (220, 200, 50),
|
||||
(220, 80, 220), ( 60, 220, 220), (255, 160, 60), (100, 200, 120),
|
||||
(180, 80, 255), (255, 140, 100),
|
||||
]
|
||||
_poly2comp = {pi: cid for cid, pis in cid_to_polys.items() for pi in pis}
|
||||
_idx2pos = {orig: i for i, orig in enumerate(keep_idx)}
|
||||
|
||||
img = img_orig.copy()
|
||||
for orig_idx, pts in zip(keep_idx, polygons):
|
||||
cid = _poly2comp.get(orig_idx)
|
||||
col = _PALETTE[(cid - 1) % len(_PALETTE)] if cid else (140, 140, 140)
|
||||
ov = img.copy()
|
||||
cv2.fillPoly(ov, [pts], col)
|
||||
cv2.addWeighted(ov, 0.35, img, 0.65, 0, img)
|
||||
cv2.polylines(img, [pts], True, col, 3)
|
||||
cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
|
||||
cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4)
|
||||
cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2)
|
||||
|
||||
for cid, pis in cid_to_polys.items():
|
||||
col = _PALETTE[(cid - 1) % len(_PALETTE)]
|
||||
valid = [_idx2pos[pi] for pi in pis if pi in _idx2pos]
|
||||
if not valid:
|
||||
continue
|
||||
cx = int(sum(int(polygons[i][:, 0].mean()) for i in valid) / len(valid))
|
||||
cy = int(sum(int(polygons[i][:, 1].mean()) for i in valid) / len(valid))
|
||||
cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 7)
|
||||
cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, col, 3)
|
||||
|
||||
h, w = img.shape[:2]
|
||||
scale = min(1.0, 4096 / max(h, w))
|
||||
if scale < 1.0:
|
||||
img = cv2.resize(img, (int(w * scale), int(h * scale)))
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imencode(output_path.suffix, img)[1].tofile(str(output_path))
|
||||
print(f" → {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--image", required=True)
|
||||
ap.add_argument("--label", required=True)
|
||||
ap.add_argument("--output", required=True)
|
||||
ap.add_argument("--polygons", type=str, default="",
|
||||
help="comma-separated polygon indices (예: '10,11,12,26,27'). 비우면 전체")
|
||||
ap.add_argument("--class-ids", type=str, default="",
|
||||
help="포함할 클래스 ID (예: '1' or '1,3'). 비우면 전체")
|
||||
ap.add_argument("--branch-radius", type=int, default=10,
|
||||
help="branch 판정 원 반지름 px (기본 10)")
|
||||
ap.add_argument("--subdetect", action="store_true",
|
||||
help="모든 그룹 sub-detect (SAM3 서버 필요)")
|
||||
ap.add_argument("--subdetect-polys", type=str, default="",
|
||||
help="지정 polygon이 속한 그룹만 sub-detect (예: '2,27')")
|
||||
args = ap.parse_args()
|
||||
selected = None
|
||||
if args.polygons.strip():
|
||||
selected = set(int(x.strip()) for x in args.polygons.split(',') if x.strip())
|
||||
class_ids = None
|
||||
if args.class_ids.strip():
|
||||
class_ids = set(int(x.strip()) for x in args.class_ids.split(',') if x.strip())
|
||||
subdetect_polys = None
|
||||
if args.subdetect_polys.strip():
|
||||
subdetect_polys = set(int(x.strip()) for x in args.subdetect_polys.split(',') if x.strip())
|
||||
render(Path(args.image), Path(args.label), Path(args.output), selected,
|
||||
branch_radius=args.branch_radius, class_ids=class_ids,
|
||||
subdetect=args.subdetect, subdetect_polys=subdetect_polys)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
777
tools/sam3_autolabel.py
Normal file
777
tools/sam3_autolabel.py
Normal file
@@ -0,0 +1,777 @@
|
||||
"""
|
||||
SAM 3.1 텍스트 프롬프트 자동 라벨링
|
||||
=====================================================
|
||||
알고리즘:
|
||||
A = 전철주 프롬프트(prompts/pole.txt) 검출 → 높이 > min_pole_h 필터
|
||||
B = 철로 프롬프트(prompts/rail.txt) 검출 → zone mask 생성
|
||||
결과 = A 중 B zone 내/근처에 있는 것만 유지
|
||||
|
||||
사용법:
|
||||
python tools/sam3_autolabel.py --input data/역사이미지/slope/
|
||||
python tools/sam3_autolabel.py --input data/역사이미지/ --output output/autolabel/
|
||||
|
||||
사전 조건:
|
||||
- start_server.bat 실행 (SAM 3.1 서버)
|
||||
"""
|
||||
import argparse
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
SAM3_MODEL_ID = "segment_anything_3"
|
||||
|
||||
_PROMPT_DIR = Path(__file__).parent.parent / "prompts"
|
||||
|
||||
CLASS_NAMES = ["catenary_pole", "bracket"]
|
||||
_CLS_COLORS = [(0, 200, 255), (255, 130, 0)]
|
||||
|
||||
|
||||
def _load_prompt(filename: str, default: str) -> str:
|
||||
path = _PROMPT_DIR / filename
|
||||
if path.exists():
|
||||
lines = [l.strip() for l in path.read_text(encoding="utf-8").splitlines()
|
||||
if l.strip() and not l.startswith("#")]
|
||||
if lines:
|
||||
return ", ".join(lines)
|
||||
return default
|
||||
|
||||
|
||||
def _get_prompts() -> tuple:
|
||||
"""(pole_prompt, rail_prompt, bracket_prompt) 반환."""
|
||||
pole = _load_prompt("pole.txt",
|
||||
"railway catenary pole, overhead line support pole, catenary mast")
|
||||
rail = _load_prompt("rail.txt", "railroad, railway")
|
||||
bracket = _load_prompt("bracket.txt",
|
||||
"catenary pole top, pole cross arm, horizontal cantilever beam, bracket arm")
|
||||
return pole, rail, bracket
|
||||
|
||||
|
||||
# ── 이미지 인코딩 ──────────────────────────────────────────────────────────────
|
||||
def encode_image(image_bgr: np.ndarray, max_size: int = 2048) -> tuple:
|
||||
h, w = image_bgr.shape[:2]
|
||||
scale = 1.0
|
||||
if max_size > 0 and max(h, w) > max_size:
|
||||
scale = max_size / max(h, w)
|
||||
image_bgr = cv2.resize(image_bgr, (int(w * scale), int(h * scale)))
|
||||
_, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
||||
return base64.b64encode(buf).decode("utf-8"), scale
|
||||
|
||||
|
||||
# ── SAM 3.1 텍스트 프롬프트 세그멘테이션 ──────────────────────────────────────
|
||||
def sam3_text_segment(image_bgr: np.ndarray, text_prompt: str,
|
||||
conf_threshold: float = 0.25,
|
||||
sam_max_size: int = 2048) -> list:
|
||||
b64, scale = encode_image(image_bgr, sam_max_size)
|
||||
payload = {
|
||||
"model": SAM3_MODEL_ID,
|
||||
"image": b64,
|
||||
"params": {
|
||||
"text_prompt": text_prompt,
|
||||
"conf_threshold": conf_threshold,
|
||||
"show_masks": True,
|
||||
"show_boxes": False,
|
||||
},
|
||||
}
|
||||
try:
|
||||
r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=180)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
if not resp.get("success"):
|
||||
print(f" SAM3 오류: {resp.get('error', {}).get('message', 'unknown')}")
|
||||
return []
|
||||
shapes = resp.get("data", {}).get("shapes", [])
|
||||
shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
|
||||
if scale < 1.0:
|
||||
inv = 1.0 / scale
|
||||
for s in shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
|
||||
return shapes
|
||||
except Exception as e:
|
||||
print(f" SAM3 오류: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def shapes_to_pairs(shapes: list) -> list:
|
||||
"""polygon shapes -> (det, shape) pairs. det = (label, score, x1,y1,x2,y2)."""
|
||||
pairs = []
|
||||
for s in shapes:
|
||||
if s.get("shape_type") != "polygon":
|
||||
continue
|
||||
pts = s.get("points", [])
|
||||
if len(pts) < 3:
|
||||
continue
|
||||
xs = [p[0] for p in pts]
|
||||
ys = [p[1] for p in pts]
|
||||
det = (s.get("label", ""), float(s.get("score", 0.0)),
|
||||
min(xs), min(ys), max(xs), max(ys))
|
||||
pairs.append((det, s))
|
||||
return pairs
|
||||
|
||||
|
||||
# ── NMS ───────────────────────────────────────────────────────────────────────
|
||||
def nms(pairs: list, iou_thresh: float = 0.5) -> list:
|
||||
if not pairs:
|
||||
return []
|
||||
boxes = np.array([[x1, y1, x2, y2] for (_, _, x1, y1, x2, y2), _ in pairs])
|
||||
scores = np.array([conf for (_, conf, *_), _ in pairs])
|
||||
order = scores.argsort()[::-1]
|
||||
keep = []
|
||||
while len(order):
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
if len(order) == 1:
|
||||
break
|
||||
xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
|
||||
yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
|
||||
xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
|
||||
yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
|
||||
inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
|
||||
a_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
|
||||
a_j = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
|
||||
iou = inter / (a_i + a_j - inter + 1e-6)
|
||||
order = order[1:][iou < iou_thresh]
|
||||
return [pairs[i] for i in keep]
|
||||
|
||||
|
||||
# ── 높이 필터 ─────────────────────────────────────────────────────────────────
|
||||
def filter_by_size(pairs: list, min_h: int, debug: bool = False) -> list:
|
||||
"""전철주 높이 필터: det bbox 높이 >= min_h, SAM mask도 충분히 높고 가늘어야 함."""
|
||||
kept = []
|
||||
for (label, conf, x1, y1, x2, y2), shape in pairs:
|
||||
det_h = y2 - y1
|
||||
if det_h < min_h:
|
||||
if debug:
|
||||
print(f" DROP [{label} {conf:.2f}] h={det_h:.0f} (최소:{min_h})")
|
||||
continue
|
||||
pts = shape.get("points", [])
|
||||
if pts:
|
||||
ys = [p[1] for p in pts]
|
||||
xs = [p[0] for p in pts]
|
||||
mh = max(ys) - min(ys)
|
||||
mw = max(xs) - min(xs)
|
||||
if mh < min_h * 0.7:
|
||||
if debug:
|
||||
print(f" DROP [{label} {conf:.2f}] sam_h={mh:.0f} (최소:{min_h*0.7:.0f})")
|
||||
continue
|
||||
if mw > mh * 3:
|
||||
if debug:
|
||||
print(f" DROP [{label} {conf:.2f}] 너무 넓음 w{mw:.0f}>h{mh:.0f}x3")
|
||||
continue
|
||||
kept.append(((label, conf, x1, y1, x2, y2), shape))
|
||||
return kept
|
||||
|
||||
|
||||
# ── 철로 존 마스크 ─────────────────────────────────────────────────────────────
|
||||
def build_zone_mask(shapes: list, H: int, W: int, margin: int,
|
||||
gap_close: int = 500,
|
||||
lateral_expand_ratio: float = 0.0) -> np.ndarray:
|
||||
"""철로 polygon → binary mask → 수평 갭 보정(beam 차폐) → margin px 팽창.
|
||||
lateral_expand_ratio > 0 이면 추가로 좌우 수평 팽창 (가장자리 기둥 포착)."""
|
||||
mask = np.zeros((H, W), dtype=np.uint8)
|
||||
for s in shapes:
|
||||
pts = np.array(s["points"], dtype=np.int32)
|
||||
cv2.fillPoly(mask, [pts], 255)
|
||||
if gap_close > 0 and mask.any():
|
||||
h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (gap_close, 30))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, h_kernel)
|
||||
if margin > 0 and mask.any():
|
||||
r = margin
|
||||
yy, xx = np.ogrid[-r:r + 1, -r:r + 1]
|
||||
kernel = (xx * xx + yy * yy <= r * r).astype(np.uint8)
|
||||
mask = cv2.dilate(mask, kernel)
|
||||
if lateral_expand_ratio > 0 and mask.any():
|
||||
lat = int(W * lateral_expand_ratio)
|
||||
if lat > 0:
|
||||
lat_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (lat * 2 + 1, 1))
|
||||
mask = cv2.dilate(mask, lat_kernel)
|
||||
return mask
|
||||
|
||||
|
||||
def pole_in_zone(pole_det: tuple, zone_mask: np.ndarray) -> bool:
|
||||
"""전철주 bbox 중심 열(column)의 y1~y2 범위 중 zone_mask와 교차하면 True.
|
||||
기둥은 철로 근처에 기저부를 두고 위/아래로 뻗으므로
|
||||
중심 한 점 대신 세로 전체 범위로 판단."""
|
||||
_, _, x1, y1, x2, y2 = pole_det
|
||||
H, W = zone_mask.shape
|
||||
cx = int(max(0, min((x1 + x2) / 2, W - 1)))
|
||||
ys = int(max(0, y1))
|
||||
ye = int(min(H, y2 + 1))
|
||||
return bool(zone_mask[ys:ye, cx].any())
|
||||
|
||||
|
||||
def _filter_pole_has_beam(pole_pairs: list, beam_pairs: list) -> list:
|
||||
"""A ∩ C: 전철주 bbox와 겹치는 Beam/Arm이 하나라도 있는 기둥만 유지.
|
||||
Beam/Arm bbox와 기둥 bbox의 IoU > 0 (단순 교차)으로 판단."""
|
||||
beam_boxes = [(x1, y1, x2, y2)
|
||||
for (_, _, x1, y1, x2, y2), _ in beam_pairs]
|
||||
|
||||
def overlaps_any_beam(pole_det):
|
||||
_, _, px1, py1, px2, py2 = pole_det
|
||||
for bx1, by1, bx2, by2 in beam_boxes:
|
||||
ix1 = max(px1, bx1); iy1 = max(py1, by1)
|
||||
ix2 = min(px2, bx2); iy2 = min(py2, by2)
|
||||
if ix2 > ix1 and iy2 > iy1:
|
||||
return True
|
||||
return False
|
||||
|
||||
return [(d, s) for d, s in pole_pairs if overlaps_any_beam(d)]
|
||||
|
||||
|
||||
def filter_long_beams(beam_pairs: list, rail_shapes: list,
|
||||
min_beam_width: int = 400,
|
||||
ratio_thresh: float = 2.0,
|
||||
rail_x_margin: int = 500,
|
||||
debug: bool = False) -> list:
|
||||
"""역사형 portal frame의 긴 빔만 유지.
|
||||
조건:
|
||||
1) bbox 폭 >= min_beam_width (가로등은 좁음)
|
||||
2) 폭/높이 >= ratio_thresh (가로형이어야 함)
|
||||
3) 빔 x-범위가 검출된 rail x-범위 ±rail_x_margin 내 (도로 차단)
|
||||
"""
|
||||
if not rail_shapes:
|
||||
if debug:
|
||||
print(f" [Beam filter] rail 미검출 → 빔 zone 보강 비활성")
|
||||
return []
|
||||
|
||||
rail_xs = []
|
||||
for s in rail_shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
xs = [p[0] for p in s["points"]]
|
||||
rail_xs.extend([min(xs), max(xs)])
|
||||
if not rail_xs:
|
||||
return []
|
||||
rail_x_min, rail_x_max = min(rail_xs), max(rail_xs)
|
||||
|
||||
kept = []
|
||||
for (label, conf, x1, y1, x2, y2), shape in beam_pairs:
|
||||
bw = x2 - x1
|
||||
bh = max(1, y2 - y1)
|
||||
ratio = bw / bh
|
||||
if bw < min_beam_width:
|
||||
if debug:
|
||||
print(f" DROP beam [{label} {conf:.2f}] w={bw:.0f} (최소:{min_beam_width})")
|
||||
continue
|
||||
if ratio < ratio_thresh:
|
||||
if debug:
|
||||
print(f" DROP beam [{label} {conf:.2f}] w/h={ratio:.2f} (최소:{ratio_thresh})")
|
||||
continue
|
||||
if x2 < rail_x_min - rail_x_margin or x1 > rail_x_max + rail_x_margin:
|
||||
if debug:
|
||||
print(f" DROP beam [{label} {conf:.2f}] rail x-범위 벗어남")
|
||||
continue
|
||||
kept.append(((label, conf, x1, y1, x2, y2), shape))
|
||||
return kept
|
||||
|
||||
|
||||
def add_beams_to_zone(zone_mask: np.ndarray, beam_pairs: list,
|
||||
downward_extend: int = 400, side_margin: int = 100) -> np.ndarray:
|
||||
"""필터된 긴 빔의 bbox를 zone_mask에 추가.
|
||||
빔 위치 + 아래쪽 downward_extend px → 레일이 빔에 차폐된 구간 보강."""
|
||||
H, W = zone_mask.shape
|
||||
for (_, _, x1, y1, x2, y2), _ in beam_pairs:
|
||||
bx1 = max(0, int(x1) - side_margin)
|
||||
bx2 = min(W, int(x2) + side_margin)
|
||||
by1 = max(0, int(y1))
|
||||
by2 = min(H, int(y2) + downward_extend)
|
||||
zone_mask[by1:by2, bx1:bx2] = 255
|
||||
return zone_mask
|
||||
|
||||
|
||||
def _mask_to_largest_polygon(mask: np.ndarray, min_area: int = 50) -> list:
|
||||
"""이진 마스크에서 가장 큰 외곽선 → polygon points."""
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return None
|
||||
largest = max(contours, key=cv2.contourArea)
|
||||
if cv2.contourArea(largest) < min_area:
|
||||
return None
|
||||
epsilon = 0.002 * cv2.arcLength(largest, True)
|
||||
approx = cv2.approxPolyDP(largest, epsilon, True)
|
||||
return [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
|
||||
|
||||
def split_pole_arm(pts: list, H: int, W: int,
|
||||
arm_min_factor: float = 0.5) -> tuple:
|
||||
"""전철주 polygon을 세로 기둥(column)과 가로 암(arm)으로 분리.
|
||||
|
||||
방법 (row-width 분석):
|
||||
1. polygon → mask
|
||||
2. 각 가로 row에서 mask width 측정
|
||||
3. 좁은 row 절반의 중앙값 = column 폭 추정
|
||||
4. column 중심 x = 좁은 row들의 mask 중점 중앙값
|
||||
5. column mask = column 중심 ± column 폭/2 범위 내 mask
|
||||
6. arm mask = column 범위 외 mask (충분히 가로로 뻗은 row만)
|
||||
|
||||
arm_min_factor: arm으로 인정할 row의 최소 width = column 폭 × factor
|
||||
반환: (column_pts, arm_pts | None)
|
||||
"""
|
||||
polygon = np.array(pts, dtype=np.int32)
|
||||
if len(polygon) < 3:
|
||||
return pts, None
|
||||
|
||||
mask = np.zeros((H, W), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [polygon], 255)
|
||||
|
||||
x1 = max(0, int(polygon[:, 0].min()))
|
||||
x2 = min(W - 1, int(polygon[:, 0].max()))
|
||||
y1 = max(0, int(polygon[:, 1].min()))
|
||||
y2 = min(H - 1, int(polygon[:, 1].max()))
|
||||
sub = mask[y1:y2 + 1, x1:x2 + 1]
|
||||
|
||||
row_widths = (sub > 0).sum(axis=1).astype(np.float32)
|
||||
valid = row_widths > 0
|
||||
if valid.sum() < 5:
|
||||
return pts, None
|
||||
|
||||
# column 폭 = 좁은 row들의 중앙값 (전체의 좁은 절반)
|
||||
sorted_widths = np.sort(row_widths[valid])
|
||||
column_w = float(np.median(sorted_widths[:max(1, len(sorted_widths) // 2)]))
|
||||
if column_w < 5:
|
||||
return pts, None
|
||||
|
||||
# column 중심 x 추정 — 좁은 row들의 mask 중점 중앙값
|
||||
narrow_threshold = column_w * 1.3
|
||||
cx_list = []
|
||||
for i, w in enumerate(row_widths):
|
||||
if 0 < w <= narrow_threshold:
|
||||
row_px = np.where(sub[i] > 0)[0]
|
||||
if len(row_px) > 0:
|
||||
cx_list.append((row_px[0] + row_px[-1]) / 2)
|
||||
if not cx_list:
|
||||
return pts, None
|
||||
column_cx = int(np.median(cx_list)) + x1
|
||||
column_half = max(4, int(column_w * 0.7))
|
||||
|
||||
col_x_min = max(0, column_cx - column_half)
|
||||
col_x_max = min(W, column_cx + column_half + 1)
|
||||
|
||||
column_mask = np.zeros_like(mask)
|
||||
column_mask[:, col_x_min:col_x_max] = mask[:, col_x_min:col_x_max]
|
||||
|
||||
arm_mask = mask.copy()
|
||||
arm_mask[:, col_x_min:col_x_max] = 0
|
||||
|
||||
# arm: 가로로 충분히 뻗은 row만 유지
|
||||
arm_min_w = max(20, int(column_w * arm_min_factor))
|
||||
arm_row_widths = (arm_mask > 0).sum(axis=1)
|
||||
for y in range(H):
|
||||
if arm_row_widths[y] < arm_min_w:
|
||||
arm_mask[y, :] = 0
|
||||
|
||||
column_pts = _mask_to_largest_polygon(column_mask)
|
||||
arm_pts = _mask_to_largest_polygon(arm_mask) if arm_mask.any() else None
|
||||
|
||||
# arm이 너무 작으면 폐기 (column 면적의 5% 미만)
|
||||
if arm_pts and column_pts:
|
||||
c_area = cv2.contourArea(np.array(column_pts, dtype=np.int32))
|
||||
a_area = cv2.contourArea(np.array(arm_pts, dtype=np.int32))
|
||||
if a_area < c_area * 0.05:
|
||||
arm_pts = None
|
||||
|
||||
return (column_pts or pts), arm_pts
|
||||
|
||||
|
||||
# ── 후처리: 합치기 + Skeleton 분리 (column / beam) ────────────────────────────
|
||||
def _compute_skeleton(mask: np.ndarray) -> np.ndarray:
|
||||
"""이진 마스크 → 1px 두께 skeleton (opencv-contrib-python의 ximgproc.thinning)."""
|
||||
return cv2.ximgproc.thinning(mask)
|
||||
|
||||
|
||||
def _find_branch_points(skel: np.ndarray) -> np.ndarray:
|
||||
"""Skeleton의 분기점(8-neighbor degree ≥ 3) 마스크 반환."""
|
||||
skel_bin = (skel > 0).astype(np.uint8)
|
||||
# 8-neighbor 합 = 3x3 box filter * 9 - center
|
||||
box_kernel = np.ones((3, 3), dtype=np.float32)
|
||||
nbr = cv2.filter2D(skel_bin.astype(np.float32), cv2.CV_32F, box_kernel) - skel_bin
|
||||
branch = (skel_bin == 1) & (nbr >= 3)
|
||||
return branch.astype(np.uint8) * 255
|
||||
|
||||
|
||||
def _compute_degree_map(skel: np.ndarray) -> np.ndarray:
|
||||
"""Skeleton 각 픽셀의 8-neighbor degree (자기 자신 제외)."""
|
||||
skel_bin = (skel > 0).astype(np.float32)
|
||||
box = np.ones((3, 3), dtype=np.float32)
|
||||
nbr = cv2.filter2D(skel_bin, cv2.CV_32F, box) - skel_bin
|
||||
deg = (nbr * skel_bin).astype(np.int32)
|
||||
return deg
|
||||
|
||||
|
||||
def _split_one_component(mask: np.ndarray) -> tuple:
|
||||
"""하나의 connected component를 column / beam mask로 분리 (위상 규칙).
|
||||
|
||||
라멘 구조 위상:
|
||||
- beam (가로 빔): 양 끝이 branch에 연결 → 원본 free endpoint 0개
|
||||
- column (세로 기둥): 한 끝이 branch, 다른 끝이 free → 원본 free endpoint 1개+
|
||||
|
||||
위상 규칙이 적용되지 않는 폴리곤(단일 segment, 가로등 'ㅏ', "ㄱ" 등)은
|
||||
라멘이 아니므로 (None, None) 반환 → 호출부에서 폴리곤 삭제.
|
||||
"""
|
||||
skel = _compute_skeleton(mask)
|
||||
if not skel.any():
|
||||
return None, None
|
||||
|
||||
# 원본 skeleton의 degree map
|
||||
deg = _compute_degree_map(skel)
|
||||
free_endpoint_mask = ((deg == 1) & (skel > 0))
|
||||
branch_mask = ((deg >= 3) & (skel > 0)).astype(np.uint8) * 255
|
||||
|
||||
# branch 제거 → segment 분리
|
||||
if branch_mask.any():
|
||||
bk = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
|
||||
seg_skel = cv2.bitwise_and(skel, cv2.bitwise_not(cv2.dilate(branch_mask, bk)))
|
||||
else:
|
||||
# branch가 전혀 없음 = 단일 선/곡선 → 라멘 아님
|
||||
return None, None
|
||||
|
||||
num, labels = cv2.connectedComponents(seg_skel)
|
||||
interior_seeds = [] # free endpoint 0개 segment
|
||||
edge_seeds = [] # free endpoint 1개+ segment
|
||||
|
||||
for sid in range(1, num):
|
||||
seg = (labels == sid)
|
||||
if seg.sum() < 10:
|
||||
continue
|
||||
n_free = int((free_endpoint_mask & seg).sum())
|
||||
seg_u8 = seg.astype(np.uint8) * 255
|
||||
if n_free == 0:
|
||||
interior_seeds.append(seg_u8)
|
||||
else:
|
||||
edge_seeds.append(seg_u8)
|
||||
|
||||
# 라멘이 되려면 interior(beam)와 edge(column)가 모두 있어야 함
|
||||
if not interior_seeds or not edge_seeds:
|
||||
return None, None
|
||||
|
||||
col_seed = np.zeros_like(mask)
|
||||
beam_seed = np.zeros_like(mask)
|
||||
for s in edge_seeds:
|
||||
col_seed = cv2.bitwise_or(col_seed, s)
|
||||
for s in interior_seeds:
|
||||
beam_seed = cv2.bitwise_or(beam_seed, s)
|
||||
|
||||
col_dist = cv2.distanceTransform(255 - col_seed, cv2.DIST_L2, 5)
|
||||
beam_dist = cv2.distanceTransform(255 - beam_seed, cv2.DIST_L2, 5)
|
||||
|
||||
mask_bool = mask > 0
|
||||
col_region = (mask_bool & (col_dist <= beam_dist)).astype(np.uint8) * 255
|
||||
beam_region = (mask_bool & (col_dist > beam_dist)).astype(np.uint8) * 255
|
||||
return col_region, beam_region
|
||||
|
||||
|
||||
def _extract_polygons(mask: np.ndarray, min_area: int = 200,
|
||||
epsilon_factor: float = 0.002) -> list:
|
||||
"""Mask의 모든 connected component → polygon points 리스트."""
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
polys = []
|
||||
for c in contours:
|
||||
if cv2.contourArea(c) < min_area:
|
||||
continue
|
||||
eps = epsilon_factor * cv2.arcLength(c, True)
|
||||
approx = cv2.approxPolyDP(c, eps, True)
|
||||
pts = [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
if len(pts) >= 3:
|
||||
polys.append(pts)
|
||||
return polys
|
||||
|
||||
|
||||
def skeleton_split_columns_beams(pole_pairs: list, H: int, W: int,
|
||||
min_component_area: int = 1000,
|
||||
min_polygon_area: int = 200) -> tuple:
|
||||
"""Pole polygon별로 skeleton split → column / beam 분류된 polygon 추출.
|
||||
|
||||
반환: (column_polygons, beam_polygons)
|
||||
"""
|
||||
if not pole_pairs:
|
||||
return [], []
|
||||
|
||||
merged = np.zeros((H, W), dtype=np.uint8)
|
||||
for _, shape in pole_pairs:
|
||||
if shape.get("shape_type") == "polygon":
|
||||
pts = np.array(shape["points"], dtype=np.int32)
|
||||
cv2.fillPoly(merged, [pts], 255)
|
||||
if not merged.any():
|
||||
return [], []
|
||||
|
||||
# closing 미적용 — SAM3.1 polygon 형상 그대로 사용
|
||||
num, labels = cv2.connectedComponents(merged)
|
||||
all_col = np.zeros_like(merged)
|
||||
all_beam = np.zeros_like(merged)
|
||||
kept = 0
|
||||
dropped = 0
|
||||
for cid in range(1, num):
|
||||
comp = (labels == cid).astype(np.uint8) * 255
|
||||
if int(comp.sum() / 255) < min_component_area:
|
||||
continue
|
||||
col_m, beam_m = _split_one_component(comp)
|
||||
if col_m is None or beam_m is None:
|
||||
dropped += 1
|
||||
continue
|
||||
all_col = cv2.bitwise_or(all_col, col_m)
|
||||
all_beam = cv2.bitwise_or(all_beam, beam_m)
|
||||
kept += 1
|
||||
print(f" [Topo] 라멘 위상 통과: {kept}개, 비라멘 폐기: {dropped}개")
|
||||
|
||||
column_polys = _extract_polygons(all_col, min_area=min_polygon_area)
|
||||
beam_polys = _extract_polygons(all_beam, min_area=min_polygon_area)
|
||||
return column_polys, beam_polys
|
||||
|
||||
|
||||
# ── 저장 ──────────────────────────────────────────────────────────────────────
|
||||
def save_yolo_label(label_path: Path, pairs: list, W: int, H: int):
|
||||
lines = []
|
||||
for (label, conf, *_), shape in pairs:
|
||||
if shape.get("shape_type") != "polygon":
|
||||
continue
|
||||
pts = shape.get("points", [])
|
||||
if len(pts) < 3:
|
||||
continue
|
||||
cls_id = shape.get("_cls_id", 0)
|
||||
coords = " ".join(f"{x/W:.6f} {y/H:.6f}" for x, y in pts)
|
||||
lines.append(f"{cls_id} {coords}")
|
||||
label_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
|
||||
|
||||
def save_vis(vis_path: Path, image_bgr: np.ndarray, pairs: list,
|
||||
rail_shapes: list = None, long_beam_pairs: list = None):
|
||||
vis = image_bgr.copy()
|
||||
for (label, conf, x1, y1, x2, y2), shape in pairs:
|
||||
cls_id = shape.get("_cls_id", 0)
|
||||
color = _CLS_COLORS[cls_id % len(_CLS_COLORS)]
|
||||
if shape.get("shape_type") == "polygon":
|
||||
pts = np.array(shape["points"], dtype=np.int32)
|
||||
overlay = vis.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis)
|
||||
cv2.polylines(vis, [pts], True, color, 2)
|
||||
cv2.putText(vis, f"{CLASS_NAMES[cls_id]} {conf:.2f}",
|
||||
(int(x1), max(0, int(y1) - 5)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||||
# 레일 zone 시각화 (빨간색, 반투명)
|
||||
if rail_shapes:
|
||||
for s in rail_shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
pts = np.array(s["points"], dtype=np.int32)
|
||||
overlay = vis.copy()
|
||||
cv2.fillPoly(overlay, [pts], (0, 0, 200))
|
||||
cv2.addWeighted(overlay, 0.25, vis, 0.75, 0, vis)
|
||||
cv2.polylines(vis, [pts], True, (0, 0, 255), 2)
|
||||
# 긴 빔 시각화 (자홍색 bbox)
|
||||
if long_beam_pairs:
|
||||
for (label, conf, x1, y1, x2, y2), _ in long_beam_pairs:
|
||||
cv2.rectangle(vis, (int(x1), int(y1)), (int(x2), int(y2)),
|
||||
(255, 0, 255), 2)
|
||||
cv2.putText(vis, f"beam {conf:.2f}", (int(x1), max(0, int(y1) - 5)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1)
|
||||
h, w = vis.shape[:2]
|
||||
if max(h, w) > 2048:
|
||||
scale = 2048 / max(h, w)
|
||||
vis = cv2.resize(vis, (int(w * scale), int(h * scale)))
|
||||
cv2.imencode(vis_path.suffix, vis)[1].tofile(str(vis_path))
|
||||
|
||||
|
||||
# ── 타일 검출 ─────────────────────────────────────────────────────────────────
|
||||
def _detect_tiled(image_bgr: np.ndarray, prompt: str, conf: float,
|
||||
tile_size: int, overlap: float) -> list:
|
||||
H, W = image_bgr.shape[:2]
|
||||
step = max(1, int(tile_size * (1 - overlap)))
|
||||
tile_coords = []
|
||||
for y0 in range(0, H, step):
|
||||
for x0 in range(0, W, step):
|
||||
tile_coords.append((x0, y0, min(x0 + tile_size, W), min(y0 + tile_size, H)))
|
||||
print(f" 총 {len(tile_coords)}개 타일")
|
||||
all_pairs = []
|
||||
for idx, (tx0, ty0, tx1, ty1) in enumerate(tile_coords, 1):
|
||||
tile_bgr = image_bgr[ty0:ty1, tx0:tx1]
|
||||
shapes = sam3_text_segment(tile_bgr, prompt, conf)
|
||||
for s in shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
s["points"] = [[x + tx0, y + ty0] for x, y in s["points"]]
|
||||
all_pairs.extend(shapes_to_pairs(shapes))
|
||||
if idx % 5 == 0:
|
||||
print(f" 타일 {idx}/{len(tile_coords)}, 누적 {len(all_pairs)}개")
|
||||
return nms(all_pairs)
|
||||
|
||||
|
||||
# ── 이미지 처리 ───────────────────────────────────────────────────────────────
|
||||
def process_image(image_path: Path,
|
||||
out_label_dir: Path, out_vis_dir: Path,
|
||||
sam3_conf: float,
|
||||
tile_size: int, tile_overlap: float,
|
||||
min_pole_h: int = 160,
|
||||
rail_margin: int = 300,
|
||||
rail_gap_close: int = 500,
|
||||
beam_min_width: int = 400,
|
||||
beam_ratio: float = 2.0,
|
||||
beam_downward: int = 400,
|
||||
rail_lateral: float = 0.0,
|
||||
raw_polygons: bool = False) -> int:
|
||||
|
||||
buf = np.fromfile(str(image_path), dtype=np.uint8)
|
||||
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image_bgr is None:
|
||||
print(f" 이미지 로드 실패: {image_path}")
|
||||
return 0
|
||||
H, W = image_bgr.shape[:2]
|
||||
|
||||
pole_prompt, rail_prompt, bracket_prompt = _get_prompts()
|
||||
use_tile = tile_size > 0 and (W > tile_size or H > tile_size)
|
||||
|
||||
# A: 전철주 검출
|
||||
if use_tile:
|
||||
print(f" [A] 타일 {tile_size}px: '{pole_prompt}'")
|
||||
pole_pairs = _detect_tiled(image_bgr, pole_prompt, sam3_conf, tile_size, tile_overlap)
|
||||
else:
|
||||
print(f" [A] 전철주 검출: '{pole_prompt}'")
|
||||
pole_pairs = nms(shapes_to_pairs(sam3_text_segment(image_bgr, pole_prompt, sam3_conf)))
|
||||
|
||||
# A 높이 필터 (> min_pole_h)
|
||||
before = len(pole_pairs)
|
||||
pole_pairs = filter_by_size(pole_pairs, min_h=min_pole_h, debug=True)
|
||||
print(f" [A] 높이 필터(h>={min_pole_h}): {before}→{len(pole_pairs)}개")
|
||||
|
||||
for _, shape in pole_pairs:
|
||||
shape["_cls_id"] = 0
|
||||
|
||||
# B: 철로 검출
|
||||
print(f" [B] 철로 검출: '{rail_prompt}'")
|
||||
rail_shapes = [s for s in sam3_text_segment(image_bgr, rail_prompt, sam3_conf)
|
||||
if s.get("shape_type") == "polygon"]
|
||||
|
||||
# [C] Beam/Arm 쿼리는 사용하지 않음 (검출되는 "긴 빔"이 실제로는 레일 일부였음)
|
||||
|
||||
# A ∩ Zone 필터 (Zone = rail + 측면 확장)
|
||||
if rail_margin >= 0:
|
||||
if rail_shapes:
|
||||
zone_mask = build_zone_mask(rail_shapes, H, W, rail_margin, rail_gap_close,
|
||||
lateral_expand_ratio=rail_lateral)
|
||||
if rail_lateral > 0:
|
||||
print(f" [Zone] rail + 측면 {rail_lateral*100:.0f}%")
|
||||
before = len(pole_pairs)
|
||||
kept = []
|
||||
for d, s in pole_pairs:
|
||||
if pole_in_zone(d, zone_mask):
|
||||
kept.append((d, s))
|
||||
else:
|
||||
label, conf, x1, y1, x2, y2 = d
|
||||
print(f" DROP zone [{label} {conf:.2f}] bbox=({x1:.0f},{y1:.0f})-({x2:.0f},{y2:.0f})")
|
||||
pole_pairs = kept
|
||||
print(f" [A∩Zone] 필터(+{rail_margin}px): {before}→{len(pole_pairs)}개")
|
||||
else:
|
||||
print(f" ! [B] 철로 미검출 → 전체 제외")
|
||||
pole_pairs = []
|
||||
|
||||
# 후처리: raw_polygons=True 면 SAM 원본 폴리곤만 cls=0으로 그대로 저장
|
||||
if raw_polygons:
|
||||
all_pairs = []
|
||||
for det, shape in pole_pairs:
|
||||
shape["_cls_id"] = 0
|
||||
shape["label"] = "catenary_pole"
|
||||
all_pairs.append((det, shape))
|
||||
print(f" [Raw] SAM 폴리곤 {len(all_pairs)}개 그대로 저장 (split 미적용)")
|
||||
else:
|
||||
# skeleton split (column / beam)
|
||||
column_polys, beam_polys = skeleton_split_columns_beams(pole_pairs, H, W)
|
||||
column_pairs = []
|
||||
arm_pairs = []
|
||||
for poly in column_polys:
|
||||
xs = [p[0] for p in poly]
|
||||
ys = [p[1] for p in poly]
|
||||
det = ("catenary_pole", 1.0, min(xs), min(ys), max(xs), max(ys))
|
||||
shape = {"shape_type": "polygon", "points": poly,
|
||||
"_cls_id": 0, "label": "catenary_pole"}
|
||||
column_pairs.append((det, shape))
|
||||
for poly in beam_polys:
|
||||
xs = [p[0] for p in poly]
|
||||
ys = [p[1] for p in poly]
|
||||
det = ("bracket", 1.0, min(xs), min(ys), max(xs), max(ys))
|
||||
shape = {"shape_type": "polygon", "points": poly,
|
||||
"_cls_id": 1, "label": "bracket"}
|
||||
arm_pairs.append((det, shape))
|
||||
print(f" [Split] column {len(column_pairs)}개, beam {len(arm_pairs)}개")
|
||||
all_pairs = column_pairs + arm_pairs
|
||||
if not all_pairs:
|
||||
return 0
|
||||
|
||||
save_yolo_label(out_label_dir / (image_path.stem + ".txt"), all_pairs, W, H)
|
||||
save_vis(out_vis_dir / (image_path.stem + "_vis.jpg"), image_bgr, all_pairs,
|
||||
rail_shapes, None)
|
||||
return len(all_pairs)
|
||||
|
||||
|
||||
# ── classes.txt 저장 ──────────────────────────────────────────────────────────
|
||||
def save_classes(out_dir: Path):
|
||||
(out_dir / "classes.txt").write_text("\n".join(CLASS_NAMES), encoding="utf-8")
|
||||
|
||||
|
||||
# ── 메인 ─────────────────────────────────────────────────────────────────────
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="SAM 3.1 텍스트 프롬프트 자동 라벨링")
|
||||
parser.add_argument("--input", required=True, help="이미지 파일 또는 디렉터리")
|
||||
parser.add_argument("--output", default="output/autolabel")
|
||||
parser.add_argument("--sam3_conf", type=float, default=0.20)
|
||||
parser.add_argument("--ext", default=".jpg,.jpeg,.png,.tif")
|
||||
parser.add_argument("--tile_size", type=int, default=1280,
|
||||
help="타일 크기 px (0=타일 없음, 기본:1280)")
|
||||
parser.add_argument("--tile_overlap", type=float, default=0.2)
|
||||
parser.add_argument("--min_pole_h", type=int, default=160,
|
||||
help="전철주 최소 높이 px (기본:160, 드론 거리 변화 마진 포함)")
|
||||
parser.add_argument("--rail_margin", type=int, default=300,
|
||||
help="철로 mask 팽창 거리 px (-1=비활성, 기본:300)")
|
||||
parser.add_argument("--rail_gap_close", type=int, default=500,
|
||||
help="beam 차폐 갭 보정 수평 closing 폭 px (0=비활성, 기본:500)")
|
||||
parser.add_argument("--beam_min_width", type=int, default=400,
|
||||
help="긴 빔 최소 폭 px - 가로등/짧은 arm 차단 (기본:400, 역사형만)")
|
||||
parser.add_argument("--beam_ratio", type=float, default=2.0,
|
||||
help="긴 빔 폭/높이 최소 비율 (기본:2.0)")
|
||||
parser.add_argument("--beam_downward", type=int, default=400,
|
||||
help="빔 zone 아래 확장 px (기본:400)")
|
||||
parser.add_argument("--rail_lateral", type=float, default=0.0,
|
||||
help="rail zone 측면(좌우) 확장 비율 (이미지 폭 기준, 예:0.10=10%%)")
|
||||
parser.add_argument("--raw_polygons", action="store_true",
|
||||
help="후처리 skeleton split 없이 SAM 원본 폴리곤만 저장")
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = Path(args.input)
|
||||
out_dir = Path(args.output)
|
||||
out_label = out_dir / "labels"
|
||||
out_vis = out_dir / "vis"
|
||||
out_label.mkdir(parents=True, exist_ok=True)
|
||||
out_vis.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
exts = [e.strip().lower() for e in args.ext.split(",")]
|
||||
if input_path.is_dir():
|
||||
images = sorted([p for p in input_path.iterdir() if p.suffix.lower() in exts],
|
||||
key=lambda p: p.name)
|
||||
else:
|
||||
images = [input_path]
|
||||
|
||||
pole_prompt, rail_prompt, bracket_prompt = _get_prompts()
|
||||
print(f"처리 대상: {len(images)}장")
|
||||
print(f"전철주 프롬프트: {pole_prompt}")
|
||||
print(f"철로 프롬프트: {rail_prompt}")
|
||||
print(f"타일: {args.tile_size}px 겹침: {args.tile_overlap*100:.0f}%")
|
||||
print(f"출력: {out_dir}\n")
|
||||
save_classes(out_dir)
|
||||
|
||||
total = 0
|
||||
for i, img_path in enumerate(images, 1):
|
||||
print(f"[{i}/{len(images)}] {img_path.name}")
|
||||
total += process_image(img_path, out_label, out_vis,
|
||||
args.sam3_conf,
|
||||
args.tile_size, args.tile_overlap,
|
||||
args.min_pole_h, args.rail_margin, args.rail_gap_close,
|
||||
args.beam_min_width, args.beam_ratio, args.beam_downward,
|
||||
args.rail_lateral, args.raw_polygons)
|
||||
|
||||
print(f"\n완료: 총 {total}개 라벨 생성")
|
||||
print(f"라벨: {out_label}")
|
||||
print(f"시각화: {out_vis}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
278
tools/sam3_batch_label.py
Normal file
278
tools/sam3_batch_label.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
SAM3 배치 자동 레이블링 파이프라인
|
||||
===================================
|
||||
SAM3 서버에 텍스트 프롬프트로 이미지 배치 처리 →
|
||||
X-AnyLabeling 호환 JSON annotation 파일 자동 생성
|
||||
|
||||
사용법:
|
||||
# 기본 (sample/rail 폴더, 서버 localhost:8000)
|
||||
python tools/sam3_batch_label.py
|
||||
|
||||
# 폴더 지정
|
||||
python tools/sam3_batch_label.py --input sample/rail --output output/labels
|
||||
|
||||
# conf 조정 (낮출수록 더 많이 검출, 오탐도 증가)
|
||||
python tools/sam3_batch_label.py --conf 0.20
|
||||
|
||||
SAM3 서버 실행:
|
||||
cd X-AnyLabeling-Server
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
MODEL_ID = "segment_anything_3"
|
||||
|
||||
# 검출 대상 + 한국어 레이블 매핑
|
||||
TARGETS = {
|
||||
"pole": "전철주_세로",
|
||||
"catenary arm": "전철주_가로",
|
||||
"junction box": "통신박스",
|
||||
"electrical box": "전기박스",
|
||||
"fence": "펜스",
|
||||
}
|
||||
|
||||
# 시각화 색상 (BGR)
|
||||
COLORS = {
|
||||
"전철주_세로": (0, 200, 255),
|
||||
"전철주_가로": (0, 100, 255),
|
||||
"통신박스": (255, 180, 0),
|
||||
"전기박스": (100, 255, 200),
|
||||
"펜스": (0, 255, 100),
|
||||
}
|
||||
|
||||
|
||||
def encode_image(image_bgr: np.ndarray) -> str:
|
||||
_, buf = cv2.imencode(".png", image_bgr) # PNG: 무손실 풀해상도
|
||||
return base64.b64encode(buf).decode("utf-8")
|
||||
|
||||
|
||||
def sam3_text_predict(image_bgr: np.ndarray, text_prompt: str, conf: float) -> list:
|
||||
"""SAM3 텍스트 프롬프트로 segmentation. shapes 리스트 반환."""
|
||||
payload = {
|
||||
"model": MODEL_ID,
|
||||
"image": encode_image(image_bgr),
|
||||
"params": {
|
||||
"text_prompt": text_prompt,
|
||||
"show_masks": True,
|
||||
"show_boxes": False,
|
||||
"conf_threshold": conf,
|
||||
"epsilon_factor": 0.002,
|
||||
},
|
||||
}
|
||||
try:
|
||||
resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if data.get("success"):
|
||||
return data.get("data", {}).get("shapes", [])
|
||||
except Exception as e:
|
||||
print(f" [ERROR] SAM3 호출 실패: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def process_image(image_path: Path, conf: float, vis_dir: Path | None) -> dict:
|
||||
"""이미지 1장: 모든 클래스 검출 → annotation dict 반환."""
|
||||
# 한글 경로 대응: np.fromfile + imdecode
|
||||
buf = np.fromfile(str(image_path), dtype=np.uint8)
|
||||
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image_bgr is None:
|
||||
return {}
|
||||
h, w = image_bgr.shape[:2]
|
||||
|
||||
all_shapes = []
|
||||
class_counts = {}
|
||||
|
||||
for eng_label, kor_label in TARGETS.items():
|
||||
shapes = sam3_text_predict(image_bgr, eng_label, conf)
|
||||
# label 필드를 한국어로 교체
|
||||
for s in shapes:
|
||||
s["label"] = kor_label
|
||||
all_shapes.extend(shapes)
|
||||
if shapes:
|
||||
class_counts[kor_label] = len(shapes)
|
||||
|
||||
# X-AnyLabeling JSON 형식
|
||||
annotation = {
|
||||
"version": "3.3.9",
|
||||
"flags": {},
|
||||
"shapes": [
|
||||
{
|
||||
"label": s["label"],
|
||||
"points": s["points"],
|
||||
"group_id": None,
|
||||
"shape_type": s.get("shape_type", "polygon"),
|
||||
"flags": {},
|
||||
"score": round(float(s.get("score", 0)), 4),
|
||||
}
|
||||
for s in all_shapes
|
||||
],
|
||||
"imagePath": image_path.name,
|
||||
"imageData": None,
|
||||
"imageHeight": h,
|
||||
"imageWidth": w,
|
||||
}
|
||||
|
||||
# 시각화
|
||||
if vis_dir is not None:
|
||||
vis = draw_vis(image_bgr, all_shapes)
|
||||
vis_path = vis_dir / f"{image_path.stem}_vis.jpg"
|
||||
cv2.imencode(".jpg", vis)[1].tofile(str(vis_path))
|
||||
|
||||
return annotation, class_counts
|
||||
|
||||
|
||||
def draw_vis(image_bgr: np.ndarray, shapes: list) -> np.ndarray:
|
||||
vis = image_bgr.copy()
|
||||
overlay = image_bgr.copy()
|
||||
|
||||
for shape in shapes:
|
||||
pts = np.array(shape["points"], dtype=np.int32)
|
||||
if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
|
||||
pts = pts[:-1]
|
||||
label = shape.get("label", "unknown")
|
||||
color = COLORS.get(label, (200, 200, 200))
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.polylines(vis, [pts], True, color, 2)
|
||||
|
||||
# 레이블 텍스트
|
||||
cx = int(pts[:, 0].mean())
|
||||
cy = int(pts[:, 1].mean())
|
||||
cv2.putText(vis, label, (cx - 30, cy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
|
||||
|
||||
cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis)
|
||||
|
||||
# 범례
|
||||
y = 25
|
||||
for kor, color in COLORS.items():
|
||||
cv2.rectangle(vis, (10, y - 12), (24, y + 2), color, -1)
|
||||
cv2.putText(vis, kor, (28, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
y += 20
|
||||
|
||||
return vis
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", default="sample/rail")
|
||||
parser.add_argument("--output", default="output/sam3_labels",
|
||||
help="JSON annotation 저장 폴더")
|
||||
parser.add_argument("--vis", default="output/sam3_vis",
|
||||
help="시각화 이미지 저장 폴더 ('' 로 비활성화)")
|
||||
parser.add_argument("--conf", type=float, default=0.20,
|
||||
help="SAM3 confidence threshold (기본 0.20)")
|
||||
parser.add_argument("--classes", nargs="+",
|
||||
help="처리할 클래스 키 목록 (기본: 전체). 예: catenary_pole catenary_arm")
|
||||
parser.add_argument("--server", default="http://localhost:8000")
|
||||
args = parser.parse_args()
|
||||
|
||||
global SAM3_SERVER, TARGETS
|
||||
SAM3_SERVER = args.server
|
||||
|
||||
if args.classes:
|
||||
key_map = {
|
||||
"catenary_pole": ("catenary pole", "전철주_세로"),
|
||||
"concrete_pole": ("concrete pole", "전철주_세로"),
|
||||
"catenary_arm": ("catenary arm", "전철주_가로"),
|
||||
"junction_box": ("junction box", "통신박스"),
|
||||
"electrical_box": ("electrical box", "전기박스"),
|
||||
"fence": ("fence", "펜스"),
|
||||
}
|
||||
TARGETS = {key_map[k][0]: key_map[k][1] for k in args.classes if k in key_map}
|
||||
|
||||
# 서버 확인
|
||||
try:
|
||||
r = requests.get(f"{SAM3_SERVER}/health", timeout=5)
|
||||
print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
|
||||
except Exception:
|
||||
print(f"[ERROR] SAM3 서버 연결 실패: {SAM3_SERVER}")
|
||||
print(" 서버를 먼저 실행하세요:")
|
||||
print(" cd X-AnyLabeling-Server && uvicorn app.main:app --port 8000")
|
||||
sys.exit(1)
|
||||
|
||||
# 입력
|
||||
input_path = Path(args.input)
|
||||
if input_path.is_file():
|
||||
images = [input_path]
|
||||
elif input_path.is_dir():
|
||||
images = sorted(
|
||||
list(input_path.glob("*.jpg")) +
|
||||
list(input_path.glob("*.jpeg")) +
|
||||
list(input_path.glob("*.png"))
|
||||
)
|
||||
else:
|
||||
print(f"[ERROR] 경로 없음: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not images:
|
||||
print(f"[ERROR] 이미지 없음: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
out_dir = Path(args.output)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
vis_dir = None
|
||||
if args.vis:
|
||||
vis_dir = Path(args.vis)
|
||||
vis_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n처리 대상: {len(images)}장")
|
||||
print(f"검출 클래스: {list(TARGETS.values())}")
|
||||
print(f"SAM3 conf: {args.conf}")
|
||||
print(f"annotation 저장: {out_dir}")
|
||||
if vis_dir:
|
||||
print(f"시각화 저장: {vis_dir}")
|
||||
print()
|
||||
|
||||
total_counts: dict = {v: 0 for v in TARGETS.values()}
|
||||
processed = 0
|
||||
|
||||
for img_path in images:
|
||||
print(f"[{processed+1}/{len(images)}] {img_path.name}")
|
||||
result = process_image(img_path, args.conf, vis_dir)
|
||||
if not result:
|
||||
print(f" [SKIP] 처리 실패")
|
||||
continue
|
||||
|
||||
annotation, class_counts = result
|
||||
n = len(annotation["shapes"])
|
||||
print(f" → {n}개 객체 검출: {class_counts if class_counts else '없음'}")
|
||||
|
||||
# JSON 저장
|
||||
json_path = out_dir / f"{img_path.stem}.json"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(annotation, f, ensure_ascii=False, indent=2)
|
||||
|
||||
for k, v in class_counts.items():
|
||||
total_counts[k] = total_counts.get(k, 0) + v
|
||||
processed += 1
|
||||
|
||||
# 요약
|
||||
print("\n" + "="*50)
|
||||
print(f"완료: {processed}/{len(images)}장 처리")
|
||||
print("\n클래스별 총 검출 수:")
|
||||
for cls, cnt in total_counts.items():
|
||||
avg = cnt / max(processed, 1)
|
||||
bar = "#" * min(cnt, 30)
|
||||
print(f" {cls:12s}: {cnt:4d}개 평균 {avg:.1f}/장 {bar}")
|
||||
print(f"\nJSON 저장: {out_dir.resolve()}")
|
||||
if vis_dir:
|
||||
print(f"시각화: {vis_dir.resolve()}")
|
||||
print("\n다음 단계:")
|
||||
print(" X-AnyLabeling → Open Image Folder → annotation 폴더 선택")
|
||||
print(" → 오탐/미탐 수동 수정 → Export → YOLO11-seg 학습")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
291
tools/sam3_everything_explore.py
Normal file
291
tools/sam3_everything_explore.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
SAM3.1 탐색 모드 (Discovery Sweep)
|
||||
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
|
||||
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
|
||||
|
||||
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
|
||||
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
|
||||
|
||||
사용법:
|
||||
python tools/sam3_everything_explore.py \\
|
||||
--input "data/역사구간/1.회덕역/..." \\
|
||||
--cols 8 --rows 6
|
||||
|
||||
사전 조건: SAM3.1 서버 실행 (start_server.bat)
|
||||
"""
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
from collections import Counter
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
SAM3_MODEL_ID = "segment_anything_3"
|
||||
|
||||
|
||||
# ── 이미지 인코딩 ─────────────────────────────────────────────────────────────
|
||||
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
|
||||
h, w = image_bgr.shape[:2]
|
||||
scale = 1.0
|
||||
if max(h, w) > max_size:
|
||||
scale = max_size / max(h, w)
|
||||
image_bgr = cv2.resize(image_bgr, (int(w * scale), int(h * scale)))
|
||||
_, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
||||
return base64.b64encode(buf).decode("utf-8"), scale
|
||||
|
||||
|
||||
# 탐색용 넓은 프롬프트 — 철도 현장에서 흔히 보이는 모든 요소 포함
|
||||
DISCOVERY_PROMPT = (
|
||||
"railroad track, railway rail, "
|
||||
"catenary pole, overhead line pole, electric pole, "
|
||||
"overhead wire, catenary wire, power line cable, "
|
||||
"railway sleeper, concrete tie, "
|
||||
"guardrail, highway barrier, road fence, "
|
||||
"bridge, viaduct, overpass, "
|
||||
"vegetation, tree, bush, grass, "
|
||||
"building, structure, roof, wall, "
|
||||
"vehicle, car, truck, "
|
||||
"road, asphalt, pavement, "
|
||||
"slope, embankment, retaining wall, "
|
||||
"noise barrier, sound wall, "
|
||||
"signal, sign board"
|
||||
)
|
||||
|
||||
|
||||
# ── SAM3.1 discovery sweep 호출 ───────────────────────────────────────────────
|
||||
def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
|
||||
"""넓은 탐색 prompt → 이미지 내 모든 요소 검출. 반환: shape dict 리스트."""
|
||||
b64, scale = encode_image(tile_bgr)
|
||||
payload = {
|
||||
"model": SAM3_MODEL_ID,
|
||||
"image": b64,
|
||||
"params": {
|
||||
"text_prompt": prompt,
|
||||
"conf_threshold": conf,
|
||||
"show_masks": True,
|
||||
"show_boxes": False,
|
||||
},
|
||||
}
|
||||
try:
|
||||
r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
if not resp.get("success"):
|
||||
return []
|
||||
shapes = resp.get("data", {}).get("shapes", [])
|
||||
shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
|
||||
if scale < 1.0:
|
||||
inv = 1.0 / scale
|
||||
for s in shapes:
|
||||
if s.get("shape_type") == "polygon":
|
||||
s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
|
||||
return [s for s in shapes if s.get("shape_type") == "polygon"]
|
||||
except Exception as e:
|
||||
print(f" [SAM3 오류] {e}")
|
||||
return []
|
||||
|
||||
|
||||
# ── NMS ───────────────────────────────────────────────────────────────────────
|
||||
def _bbox(pts):
|
||||
xs = [p[0] for p in pts]; ys = [p[1] for p in pts]
|
||||
return min(xs), min(ys), max(xs), max(ys)
|
||||
|
||||
|
||||
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
|
||||
if not shapes:
|
||||
return []
|
||||
bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
|
||||
scores = np.array([float(s.get("score", 0.5)) for s in shapes])
|
||||
order = scores.argsort()[::-1]
|
||||
keep = []
|
||||
while len(order):
|
||||
i = order[0]; keep.append(i)
|
||||
if len(order) == 1: break
|
||||
xx1 = np.maximum(bboxes[i,0], bboxes[order[1:],0])
|
||||
yy1 = np.maximum(bboxes[i,1], bboxes[order[1:],1])
|
||||
xx2 = np.minimum(bboxes[i,2], bboxes[order[1:],2])
|
||||
yy2 = np.minimum(bboxes[i,3], bboxes[order[1:],3])
|
||||
inter = np.maximum(0, xx2-xx1) * np.maximum(0, yy2-yy1)
|
||||
a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1])
|
||||
a_j = (bboxes[order[1:],2]-bboxes[order[1:],0])*(bboxes[order[1:],3]-bboxes[order[1:],1])
|
||||
iou = inter / (a_i + a_j - inter + 1e-6)
|
||||
order = order[1:][iou < iou_thresh]
|
||||
return [shapes[i] for i in keep]
|
||||
|
||||
|
||||
# ── 타일 분할 + 병렬 검출 ─────────────────────────────────────────────────────
|
||||
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt):
|
||||
H, W = image_bgr.shape[:2]
|
||||
base_w = W / cols
|
||||
base_h = H / rows
|
||||
pad_x = int(base_w * overlap)
|
||||
pad_y = int(base_h * overlap)
|
||||
|
||||
tiles = []
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
idx = r * cols + c + 1
|
||||
x0 = max(0, int(c * base_w) - pad_x)
|
||||
x1 = min(W, int((c + 1) * base_w) + pad_x)
|
||||
y0 = max(0, int(r * base_h) - pad_y)
|
||||
y1 = min(H, int((r + 1) * base_h) + pad_y)
|
||||
tiles.append((idx, x0, y0, x1, y1))
|
||||
|
||||
total = len(tiles)
|
||||
done = [0]
|
||||
all_shapes = []
|
||||
|
||||
def process(args):
|
||||
idx, x0, y0, x1, y1 = args
|
||||
tile = image_bgr[y0:y1, x0:x1]
|
||||
shapes = sam3_everything(tile, conf, prompt)
|
||||
for s in shapes:
|
||||
s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
|
||||
return shapes
|
||||
|
||||
with ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
futs = {ex.submit(process, t): t for t in tiles}
|
||||
for fut in as_completed(futs):
|
||||
result = fut.result()
|
||||
all_shapes.extend(result)
|
||||
done[0] += 1
|
||||
print(f" 타일 {done[0]:02d}/{total} 완료, 누적 {len(all_shapes)}개", end="\r")
|
||||
print()
|
||||
return all_shapes
|
||||
|
||||
|
||||
# ── 시각화 ────────────────────────────────────────────────────────────────────
|
||||
def draw_everything(image_bgr, shapes, cols, rows):
|
||||
vis = image_bgr.copy()
|
||||
H, W = vis.shape[:2]
|
||||
|
||||
# 타일 경계
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
bx0, by0 = int(c * W / cols), int(r * H / rows)
|
||||
bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows)
|
||||
cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
for s in shapes:
|
||||
pts = np.array(s["points"], dtype=np.int32)
|
||||
color = tuple(int(v) for v in rng.integers(80, 255, size=3))
|
||||
overlay = vis.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.30, vis, 0.70, 0, vis)
|
||||
cv2.polylines(vis, [pts], True, color, 1)
|
||||
|
||||
# 라벨 표시 (있을 경우)
|
||||
label = s.get("label", "")
|
||||
if label:
|
||||
cx = int(np.mean([p[0] for p in s["points"]]))
|
||||
cy = int(np.mean([p[1] for p in s["points"]]))
|
||||
cv2.putText(vis, label, (cx, cy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)
|
||||
|
||||
cv2.putText(vis, f"total segments: {len(shapes)}",
|
||||
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
|
||||
return vis
|
||||
|
||||
|
||||
# ── 라벨 분석 → text_prompt 후보 출력 ────────────────────────────────────────
|
||||
def analyze_labels(shapes):
|
||||
labels = [s.get("label", "").strip() for s in shapes if s.get("label", "").strip()]
|
||||
if not labels:
|
||||
print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
|
||||
return
|
||||
|
||||
counter = Counter(labels)
|
||||
print(f"\n{'─'*50}")
|
||||
print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
|
||||
print(f"{'─'*50}")
|
||||
for label, cnt in counter.most_common(30):
|
||||
bar = "#" * min(cnt, 40)
|
||||
print(f" {label:35s} {cnt:4d} {bar}")
|
||||
print(f"{'─'*50}")
|
||||
|
||||
top_labels = [lb for lb, _ in counter.most_common(10)]
|
||||
print(f"\n[text_prompt 후보]")
|
||||
print(f' "{", ".join(top_labels)}"')
|
||||
|
||||
|
||||
# ── 메인 ─────────────────────────────────────────────────────────────────────
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
ap.add_argument("--input", required=True, help="입력 이미지")
|
||||
ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
|
||||
ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
|
||||
ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
|
||||
ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
|
||||
ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
|
||||
ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
|
||||
ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
|
||||
ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
|
||||
args = ap.parse_args()
|
||||
|
||||
prompt = DISCOVERY_PROMPT + (", " + args.prompt_extra.strip(", ") if args.prompt_extra.strip() else "")
|
||||
|
||||
img_path = Path(args.input)
|
||||
buf = np.fromfile(str(img_path), dtype=np.uint8)
|
||||
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image_bgr is None:
|
||||
print(f"이미지 로드 실패: {img_path}"); return
|
||||
|
||||
H, W = image_bgr.shape[:2]
|
||||
print(f"이미지 : {W}×{H}")
|
||||
print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}개")
|
||||
print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")
|
||||
|
||||
import time
|
||||
print(f"탐색 프롬프트 ({len(prompt.split(','))}개 항목):")
|
||||
for item in prompt.split(","):
|
||||
print(f" · {item.strip()}")
|
||||
print()
|
||||
|
||||
t0 = time.time()
|
||||
shapes = detect_everything_tiled(
|
||||
image_bgr, args.cols, args.rows, args.overlap,
|
||||
args.conf, args.workers, prompt
|
||||
)
|
||||
print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
|
||||
shapes = nms_shapes(shapes, iou_thresh=args.nms)
|
||||
print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")
|
||||
|
||||
analyze_labels(shapes)
|
||||
|
||||
vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
|
||||
h, w = vis.shape[:2]
|
||||
if max(h, w) > 4096:
|
||||
s = 4096 / max(h, w)
|
||||
vis = cv2.resize(vis, (int(w*s), int(h*s)))
|
||||
|
||||
out_path = (Path(args.output) if args.output
|
||||
else img_path.parent / (img_path.stem + "_everything.jpg"))
|
||||
cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path))
|
||||
print(f"\n저장: {out_path}")
|
||||
|
||||
# JSON으로 라벨 데이터도 저장 (분석용)
|
||||
json_path = out_path.with_suffix(".json")
|
||||
label_data = {
|
||||
"total_segments": len(shapes),
|
||||
"label_counts": dict(Counter(
|
||||
s.get("label", "(no label)") for s in shapes
|
||||
)),
|
||||
"segments": [
|
||||
{"label": s.get("label",""), "score": s.get("score",0),
|
||||
"bbox": list(_bbox(s["points"]))}
|
||||
for s in shapes
|
||||
]
|
||||
}
|
||||
json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"라벨 데이터: {json_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
219
tools/sam3_receipt_ocr.py
Normal file
219
tools/sam3_receipt_ocr.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
# Add server to path so we can import sam3 locally
|
||||
server_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server"))
|
||||
models_path = os.path.join(server_path, "app", "models")
|
||||
if server_path not in sys.path:
|
||||
sys.path.insert(0, server_path)
|
||||
if models_path not in sys.path:
|
||||
sys.path.insert(0, models_path)
|
||||
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
|
||||
# PaddleOCR will be imported inside the function to avoid errors if not installed
|
||||
def run_paddleocr(image_np):
|
||||
# Disable PIR API and oneDNN to avoid NotImplementedError on PaddlePaddle 3.x Windows
|
||||
import os as _os
|
||||
_os.environ["FLAGS_use_mkldnn"] = "0"
|
||||
_os.environ["PADDLE_WITH_MKLDNN"] = "0"
|
||||
_os.environ["FLAGS_enable_pir_api"] = "0" # Disable the new PIR API which causes this error
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
# Force use_gpu=False if oneDNN is failing on CPU, or let it detect.
|
||||
# On Windows, sometimes CPU + oneDNN is the default and it fails.
|
||||
ocr = PaddleOCR(use_textline_orientation=True, lang='korean', use_gpu=torch.cuda.is_available())
|
||||
result = ocr.ocr(image_np, cls=True) # Use .ocr() instead of .predict() for better compatibility
|
||||
return result
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
offset = 1.0 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
pts_x, pts_y = np.meshgrid(points_one_side, points_one_side)
|
||||
grid = np.stack([pts_x.flatten(), pts_y.flatten()], axis=1)
|
||||
return grid
|
||||
|
||||
def mask_iou(mask1, mask2):
|
||||
inter = np.logical_and(mask1, mask2).sum()
|
||||
union = np.logical_or(mask1, mask2).sum()
|
||||
if union == 0: return 0
|
||||
return inter / union
|
||||
|
||||
def get_receipt_mask(image_bgr, model_path, points_per_side=32):
|
||||
print("Loading SAM3 Model for receipt detection...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
model = build_sam3_image_model(
|
||||
bpe_path=bpe_path,
|
||||
device=device,
|
||||
checkpoint_path=model_path,
|
||||
)
|
||||
|
||||
processor = Sam3Processor(model, confidence_threshold=0.7, device=device)
|
||||
|
||||
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
|
||||
state = processor.set_image(pil_image)
|
||||
grid_points = build_point_grid(points_per_side)
|
||||
|
||||
masks = []
|
||||
scores = []
|
||||
|
||||
print(f"Sampling {len(grid_points)} points...")
|
||||
for i, (nx, ny) in enumerate(grid_points):
|
||||
processor.reset_all_prompts(state)
|
||||
state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
|
||||
|
||||
if "masks" in state and len(state["masks"]) > 0:
|
||||
best_idx = torch.argmax(state["scores"])
|
||||
mask = state["masks"][best_idx].cpu().numpy()
|
||||
score = state["scores"][best_idx].item()
|
||||
if score > 0.8:
|
||||
masks.append(mask)
|
||||
scores.append(score)
|
||||
|
||||
if not masks:
|
||||
return None
|
||||
|
||||
# Pick the largest mask that covers a significant area (receipt is usually big)
|
||||
areas = [m.sum() for m in masks]
|
||||
# Simple heuristic: largest mask
|
||||
best_mask_idx = np.argmax(areas)
|
||||
return masks[best_mask_idx]
|
||||
|
||||
def crop_and_warp(image, mask):
|
||||
# Ensure mask is 2D
|
||||
mask = np.squeeze(mask)
|
||||
if mask.ndim != 2:
|
||||
print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...")
|
||||
if mask.ndim == 3: mask = mask[0]
|
||||
|
||||
# Find contours
|
||||
mask_uint8 = (mask > 0).astype(np.uint8) * 255
|
||||
if np.sum(mask_uint8) == 0:
|
||||
print("Warning: Mask is empty.")
|
||||
return image
|
||||
|
||||
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return image
|
||||
|
||||
cnt = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Get bounding box
|
||||
x, y, w, h = cv2.boundingRect(cnt)
|
||||
|
||||
# Try to find 4 corners for perspective transform
|
||||
epsilon = 0.02 * cv2.arcLength(cnt, True)
|
||||
approx = cv2.approxPolyDP(cnt, epsilon, True)
|
||||
|
||||
if len(approx) == 4:
|
||||
print("Found 4 corners, applying perspective transform...")
|
||||
pts = approx.reshape(4, 2)
|
||||
# Sort points: top-left, top-right, bottom-right, bottom-left
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
rect[0] = pts[np.argmin(s)]
|
||||
rect[2] = pts[np.argmax(s)]
|
||||
diff = np.diff(pts, axis=1)
|
||||
rect[1] = pts[np.argmin(diff)]
|
||||
rect[3] = pts[np.argmax(diff)]
|
||||
|
||||
(tl, tr, br, bl) = rect
|
||||
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
|
||||
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
|
||||
maxWidth = max(int(widthA), int(widthB))
|
||||
|
||||
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
|
||||
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
|
||||
maxHeight = max(int(heightA), int(heightB))
|
||||
|
||||
dst = np.array([
|
||||
[0, 0],
|
||||
[maxWidth - 1, 0],
|
||||
[maxWidth - 1, maxHeight - 1],
|
||||
[0, maxHeight - 1]], dtype="float32")
|
||||
|
||||
M = cv2.getPerspectiveTransform(rect, dst)
|
||||
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
|
||||
return warped
|
||||
else:
|
||||
print("Could not find 4 clear corners, just cropping to bounding box.")
|
||||
# Create a mask image to black out background
|
||||
masked_img = cv2.bitwise_and(image, image, mask=mask_uint8)
|
||||
crop = masked_img[y:y+h, x:x+w]
|
||||
return crop
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", required=True, help="Input image path")
|
||||
parser.add_argument("--output_dir", default="output/ocr", help="Output directory")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Read image
|
||||
buf = np.fromfile(args.input, dtype=np.uint8)
|
||||
image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image is None:
|
||||
print(f"Failed to read {args.input}")
|
||||
return
|
||||
|
||||
model_path = os.path.join(server_path, "sam3.pt")
|
||||
|
||||
# 1. SAM3 Masking
|
||||
mask = get_receipt_mask(image, model_path)
|
||||
if mask is None:
|
||||
print("No receipt-like object found.")
|
||||
return
|
||||
|
||||
# 2. Cropping / Warping
|
||||
processed_img = crop_and_warp(image, mask)
|
||||
|
||||
# Save processed image for debugging
|
||||
processed_path = os.path.join(args.output_dir, "processed_receipt.jpg")
|
||||
cv2.imwrite(processed_path, processed_img)
|
||||
print(f"Saved processed image to {processed_path}")
|
||||
|
||||
# 3. PaddleOCR
|
||||
print("Running PaddleOCR...")
|
||||
ocr_results = run_paddleocr(processed_img)
|
||||
|
||||
# 4. Save results
|
||||
output_json = os.path.join(args.output_dir, "ocr_results.json")
|
||||
with open(output_json, "w", encoding="utf-8") as f:
|
||||
json.dump(ocr_results, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"OCR results saved to {output_json}")
|
||||
|
||||
# Print summary
|
||||
if ocr_results:
|
||||
print("\n--- OCR Extracted Text ---")
|
||||
for page in ocr_results:
|
||||
if page is None:
|
||||
continue
|
||||
# New v3.x format: list of dicts with 'rec_text' key
|
||||
if isinstance(page, dict):
|
||||
text = page.get('rec_text', '')
|
||||
score = page.get('rec_score', 0)
|
||||
if text:
|
||||
print(f"{text} (conf: {score:.2f})")
|
||||
elif isinstance(page, list):
|
||||
for line in page:
|
||||
if line and isinstance(line, list) and len(line) >= 2:
|
||||
print(line[1][0])
|
||||
print("--------------------------")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
185
tools/sam3_segment_everything.py
Normal file
185
tools/sam3_segment_everything.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Add server to path so we can import sam3 locally
|
||||
server_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server"))
|
||||
models_path = os.path.join(server_path, "app", "models")
|
||||
if server_path not in sys.path:
|
||||
sys.path.insert(0, server_path)
|
||||
if models_path not in sys.path:
|
||||
sys.path.insert(0, models_path)
|
||||
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from app.models.segment_anything_3 import SegmentAnything3
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0, 1] x [0, 1]."""
|
||||
offset = 1.0 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
pts_x, pts_y = np.meshgrid(points_one_side, points_one_side)
|
||||
grid = np.stack([pts_x.flatten(), pts_y.flatten()], axis=1)
|
||||
return grid
|
||||
|
||||
def mask_iou(mask1, mask2):
|
||||
inter = np.logical_and(mask1, mask2).sum()
|
||||
union = np.logical_or(mask1, mask2).sum()
|
||||
if union == 0:
|
||||
return 0
|
||||
return inter / union
|
||||
|
||||
def mask_to_polygon(mask, epsilon_factor=0.001):
|
||||
mask = np.squeeze(mask)
|
||||
mask_uint8 = (mask > 0).astype(np.uint8)
|
||||
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return []
|
||||
largest = max(contours, key=cv2.contourArea)
|
||||
if epsilon_factor > 0:
|
||||
epsilon = epsilon_factor * cv2.arcLength(largest, True)
|
||||
approx = cv2.approxPolyDP(largest, epsilon, True)
|
||||
else:
|
||||
approx = largest
|
||||
points = [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
return points
|
||||
|
||||
def segment_everything(image_bgr, model_path, points_per_side=32, conf_thresh=0.8, nms_thresh=0.5):
|
||||
print("Loading SAM3 Model locally...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
model = build_sam3_image_model(
|
||||
bpe_path=bpe_path,
|
||||
device=device,
|
||||
checkpoint_path=model_path,
|
||||
)
|
||||
|
||||
if device == "cuda":
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
processor = Sam3Processor(model, confidence_threshold=conf_thresh, device=device)
|
||||
|
||||
# PIL image format
|
||||
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||
from PIL import Image
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
|
||||
print("Computing image embedding...")
|
||||
t0 = time.time()
|
||||
state = processor.set_image(pil_image)
|
||||
print(f"Image embedding done in {time.time() - t0:.2f}s")
|
||||
|
||||
grid_points = build_point_grid(points_per_side)
|
||||
print(f"Generated {len(grid_points)} grid points for sampling.")
|
||||
|
||||
masks = []
|
||||
scores = []
|
||||
|
||||
t0 = time.time()
|
||||
for i, (nx, ny) in enumerate(grid_points):
|
||||
if i % 100 == 0:
|
||||
print(f" Processed {i}/{len(grid_points)} points...")
|
||||
|
||||
processor.reset_all_prompts(state)
|
||||
state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
|
||||
|
||||
if "masks" in state and len(state["masks"]) > 0:
|
||||
# Take the mask with the highest score
|
||||
best_idx = torch.argmax(state["scores"])
|
||||
mask = state["masks"][best_idx].cpu().numpy()
|
||||
score = state["scores"][best_idx].item()
|
||||
if score > conf_thresh:
|
||||
masks.append(mask)
|
||||
scores.append(score)
|
||||
print(f"Grid prediction done in {time.time() - t0:.2f}s")
|
||||
print(f"Found {len(masks)} raw masks.")
|
||||
|
||||
if not masks:
|
||||
return []
|
||||
|
||||
# Simple NMS based on IoU
|
||||
print("Applying NMS...")
|
||||
order = np.argsort(scores)[::-1]
|
||||
keep = []
|
||||
|
||||
for idx in order:
|
||||
if len(keep) == 0:
|
||||
keep.append(idx)
|
||||
continue
|
||||
|
||||
current_mask = masks[idx]
|
||||
overlap = False
|
||||
for k in keep:
|
||||
iou = mask_iou(current_mask, masks[k])
|
||||
if iou > nms_thresh:
|
||||
overlap = True
|
||||
break
|
||||
if not overlap:
|
||||
keep.append(idx)
|
||||
|
||||
final_masks = [masks[idx] for idx in keep]
|
||||
final_scores = [scores[idx] for idx in keep]
|
||||
print(f"Kept {len(final_masks)} masks after NMS.")
|
||||
|
||||
# Convert to polygons
|
||||
results = []
|
||||
for m, s in zip(final_masks, final_scores):
|
||||
poly = mask_to_polygon(m)
|
||||
if poly:
|
||||
results.append({"polygon": poly, "score": s})
|
||||
|
||||
return results
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", required=True, help="Input image path")
|
||||
parser.add_argument("--output", required=True, help="Output vis image path")
|
||||
parser.add_argument("--points", type=int, default=32, help="Points per side")
|
||||
args = parser.parse_args()
|
||||
|
||||
buf = np.fromfile(args.input, dtype=np.uint8)
|
||||
image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image is None:
|
||||
print(f"Failed to read {args.input}")
|
||||
return
|
||||
|
||||
# Shrink image if too large just to make NMS faster
|
||||
h, w = image.shape[:2]
|
||||
max_dim = 1024
|
||||
if max(h, w) > max_dim:
|
||||
scale = max_dim / max(h, w)
|
||||
image_proc = cv2.resize(image, (int(w * scale), int(h * scale)))
|
||||
else:
|
||||
image_proc = image.copy()
|
||||
|
||||
model_path = os.path.join(server_path, "sam3.pt")
|
||||
|
||||
results = segment_everything(image_proc, model_path, points_per_side=args.points, conf_thresh=0.7, nms_thresh=0.7)
|
||||
|
||||
vis = image_proc.copy()
|
||||
np.random.seed(42)
|
||||
for res in results:
|
||||
poly = res["polygon"]
|
||||
pts = np.array(poly, dtype=np.int32)
|
||||
color = np.random.randint(0, 255, (3,)).tolist()
|
||||
|
||||
overlay = vis.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.4, vis, 0.6, 0, vis)
|
||||
cv2.polylines(vis, [pts], True, color, 1)
|
||||
|
||||
# Fix unicode paths in output
|
||||
is_success, im_buf_arr = cv2.imencode(".jpg", vis)
|
||||
if is_success:
|
||||
im_buf_arr.tofile(args.output)
|
||||
print(f"Saved visualization to {args.output}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
72
tools/show_tiles.py
Normal file
72
tools/show_tiles.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
이미지를 cols×rows 타일로 분할하고 번호를 좌상단에 표시.
|
||||
사용법:
|
||||
python tools/show_tiles.py --input <이미지> --cols 8 --rows 6 --overlap 0.10
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--input", required=True)
|
||||
ap.add_argument("--output", default=None)
|
||||
ap.add_argument("--cols", type=int, default=8)
|
||||
ap.add_argument("--rows", type=int, default=6)
|
||||
ap.add_argument("--overlap", type=float, default=0.10)
|
||||
args = ap.parse_args()
|
||||
|
||||
img_path = Path(args.input)
|
||||
buf = np.fromfile(str(img_path), dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
H, W = img.shape[:2]
|
||||
|
||||
base_w = W / args.cols
|
||||
base_h = H / args.rows
|
||||
pad_x = int(base_w * args.overlap)
|
||||
pad_y = int(base_h * args.overlap)
|
||||
|
||||
vis = img.copy()
|
||||
|
||||
for r in range(args.rows):
|
||||
for c in range(args.cols):
|
||||
idx = r * args.cols + c + 1
|
||||
x0 = max(0, int(c * base_w) - pad_x)
|
||||
x1 = min(W, int((c + 1) * base_w) + pad_x)
|
||||
y0 = max(0, int(r * base_h) - pad_y)
|
||||
y1 = min(H, int((r + 1) * base_h) + pad_y)
|
||||
|
||||
# 타일 경계선 (기준선 = 중복 없는 경계)
|
||||
bx0 = int(c * base_w)
|
||||
by0 = int(r * base_h)
|
||||
bx1 = min(W, int((c + 1) * base_w))
|
||||
by1 = min(H, int((r + 1) * base_h))
|
||||
cv2.rectangle(vis, (bx0, by0), (bx1, by1), (0, 200, 255), 3)
|
||||
|
||||
# 번호 좌상단
|
||||
font_scale = max(0.8, min(W, H) / 3000)
|
||||
thickness = max(2, int(font_scale * 3))
|
||||
tx, ty = bx0 + 10, by0 + int(base_h * 0.12)
|
||||
|
||||
# 배경 박스
|
||||
label = str(idx)
|
||||
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale * 2, thickness)
|
||||
cv2.rectangle(vis, (tx - 4, ty - th - 4), (tx + tw + 4, ty + 4), (0, 0, 0), -1)
|
||||
cv2.putText(vis, label, (tx, ty),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, font_scale * 2, (0, 200, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
# 출력 크기 제한
|
||||
h, w = vis.shape[:2]
|
||||
if max(h, w) > 4096:
|
||||
s = 4096 / max(h, w)
|
||||
vis = cv2.resize(vis, (int(w * s), int(h * s)))
|
||||
|
||||
out = (Path(args.output) if args.output
|
||||
else img_path.parent / (img_path.stem + "_tiles.jpg"))
|
||||
cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 92])[1].tofile(str(out))
|
||||
print(f"저장: {out} ({args.cols}×{args.rows}={args.cols*args.rows}타일)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
237
tools/video_sam3_segment.py
Normal file
237
tools/video_sam3_segment.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
드론 영상에서 프레임 추출 → SAM3 서버로 모든 객체 세그멘테이션
|
||||
Usage:
|
||||
1. SAM3 서버 시작: start_server.bat
|
||||
2. python tools/video_sam3_segment.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import requests
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# === 설정 ===
|
||||
VIDEO_PATH = Path("sample/rail.mp4")
|
||||
OUTPUT_DIR = Path("output/video_segmentation")
|
||||
SAM3_URL = "http://localhost:8000"
|
||||
MODEL_ID = "segment_anything_3"
|
||||
FRAME_INTERVAL = 30 # 30fps 영상에서 1초 간격
|
||||
|
||||
# 철도 시설물 + 일반 객체 프롬프트
|
||||
PROMPTS = [
|
||||
"catenary pole", # 전철주
|
||||
"junction box", # 전기박스
|
||||
"utility box", # 통신박스
|
||||
"rail track", # 레일
|
||||
"fence", # 펜스
|
||||
"cable", # 전선/케이블
|
||||
"sign", # 표지판
|
||||
"building", # 건물
|
||||
"vegetation", # 식생
|
||||
]
|
||||
|
||||
# 클래스별 색상 (BGR)
|
||||
COLORS = {
|
||||
"catenary pole": (0, 0, 255), # 빨강
|
||||
"junction box": (0, 165, 255), # 주황
|
||||
"utility box": (0, 255, 255), # 노랑
|
||||
"rail track": (255, 0, 0), # 파랑
|
||||
"fence": (255, 0, 255), # 자홍
|
||||
"cable": (0, 255, 0), # 초록
|
||||
"sign": (255, 255, 0), # 시안
|
||||
"building": (128, 128, 255), # 연한 빨강
|
||||
"vegetation": (0, 128, 0), # 진한 초록
|
||||
}
|
||||
|
||||
|
||||
def extract_frames(video_path: Path, interval: int) -> list[tuple[int, np.ndarray]]:
|
||||
"""동영상에서 일정 간격으로 프레임 추출"""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
print(f"ERROR: Cannot open video: {video_path}")
|
||||
sys.exit(1)
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
print(f"Video: {video_path.name} | {total} frames @ {fps:.0f}fps | {total/fps:.1f}s")
|
||||
|
||||
frames = []
|
||||
idx = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
if idx % interval == 0:
|
||||
frames.append((idx, frame))
|
||||
idx += 1
|
||||
|
||||
cap.release()
|
||||
print(f"Extracted {len(frames)} frames (interval={interval})")
|
||||
return frames
|
||||
|
||||
|
||||
def encode_frame(frame: np.ndarray) -> str:
|
||||
"""프레임을 base64로 인코딩"""
|
||||
_, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
def predict_sam3(image_b64: str, text_prompt: str) -> dict:
|
||||
"""SAM3 서버에 예측 요청"""
|
||||
payload = {
|
||||
"model": MODEL_ID,
|
||||
"image": image_b64,
|
||||
"params": {
|
||||
"text_prompt": text_prompt,
|
||||
"conf_threshold": 0.25,
|
||||
},
|
||||
}
|
||||
try:
|
||||
resp = requests.post(f"{SAM3_URL}/v1/predict", json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
print(f"ERROR: SAM3 서버에 연결할 수 없습니다. start_server.bat을 먼저 실행하세요.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f" ERROR [{text_prompt}]: {e}")
|
||||
return {"success": False}
|
||||
|
||||
|
||||
def draw_shapes_on_frame(frame: np.ndarray, shapes: list, prompt: str) -> np.ndarray:
|
||||
"""세그멘테이션 결과를 프레임 위에 그리기"""
|
||||
overlay = frame.copy()
|
||||
color = COLORS.get(prompt, (200, 200, 200))
|
||||
|
||||
for shape in shapes:
|
||||
points = np.array(shape["points"], dtype=np.int32)
|
||||
shape_type = shape.get("shape_type", "polygon")
|
||||
|
||||
if shape_type == "polygon" and len(points) >= 3:
|
||||
cv2.fillPoly(overlay, [points], color)
|
||||
cv2.polylines(frame, [points], True, color, 2)
|
||||
elif shape_type == "rectangle" and len(points) == 2:
|
||||
cv2.rectangle(overlay, tuple(points[0]), tuple(points[1]), color, -1)
|
||||
cv2.rectangle(frame, tuple(points[0]), tuple(points[1]), color, 2)
|
||||
|
||||
# 라벨 텍스트
|
||||
label = shape.get("label", prompt)
|
||||
score = shape.get("score")
|
||||
text = f"{label}" + (f" {score:.2f}" if score else "")
|
||||
if len(points) > 0:
|
||||
tx, ty = int(points[0][0]), int(points[0][1]) - 5
|
||||
cv2.putText(frame, text, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||
|
||||
# 반투명 오버레이 블렌딩
|
||||
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
||||
return frame
|
||||
|
||||
|
||||
def create_layer_image(frame_shape: tuple, shapes: list, prompt: str) -> np.ndarray:
|
||||
"""클래스별 개별 레이어 이미지 (투명 배경 위 마스크)"""
|
||||
h, w = frame_shape[:2]
|
||||
color = COLORS.get(prompt, (200, 200, 200))
|
||||
bgr = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
alpha = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
for shape in shapes:
|
||||
points = np.array(shape["points"], dtype=np.int32)
|
||||
if shape.get("shape_type") == "polygon" and len(points) >= 3:
|
||||
cv2.fillPoly(bgr, [points], color)
|
||||
cv2.fillPoly(alpha, [points], 180)
|
||||
|
||||
layer = np.dstack([bgr, alpha])
|
||||
return layer
|
||||
|
||||
|
||||
def main():
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. 서버 상태 확인
|
||||
print("=== SAM3 서버 연결 확인 ===")
|
||||
try:
|
||||
health = requests.get(f"{SAM3_URL}/health", timeout=5)
|
||||
print(f"Server status: {health.json()}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("ERROR: SAM3 서버가 실행 중이 아닙니다!")
|
||||
print(" → start_server.bat을 먼저 실행하세요.")
|
||||
sys.exit(1)
|
||||
|
||||
# 2. 프레임 추출
|
||||
print("\n=== 프레임 추출 ===")
|
||||
frames = extract_frames(VIDEO_PATH, FRAME_INTERVAL)
|
||||
|
||||
# 3. 각 프레임별, 각 프롬프트별 세그멘테이션
|
||||
all_results = {}
|
||||
|
||||
for frame_idx, (fidx, frame) in enumerate(frames):
|
||||
print(f"\n=== Frame {fidx} ({frame_idx+1}/{len(frames)}) ===")
|
||||
|
||||
frame_b64 = encode_frame(frame)
|
||||
frame_results = {}
|
||||
composite = frame.copy()
|
||||
|
||||
for prompt in PROMPTS:
|
||||
print(f" Segmenting: {prompt}...", end=" ")
|
||||
result = predict_sam3(frame_b64, prompt)
|
||||
|
||||
if result.get("success") and result.get("data", {}).get("shapes"):
|
||||
shapes = result["data"]["shapes"]
|
||||
n = len(shapes)
|
||||
print(f"→ {n} objects found")
|
||||
frame_results[prompt] = shapes
|
||||
|
||||
# 합성 이미지에 그리기
|
||||
composite = draw_shapes_on_frame(composite, shapes, prompt)
|
||||
|
||||
# 개별 레이어 저장 (PNG with alpha)
|
||||
layer = create_layer_image(frame.shape, shapes, prompt)
|
||||
layer_name = prompt.replace(" ", "_")
|
||||
layer_path = OUTPUT_DIR / f"frame_{fidx:04d}_layer_{layer_name}.png"
|
||||
cv2.imwrite(str(layer_path), layer)
|
||||
else:
|
||||
print("→ no objects")
|
||||
|
||||
# 합성 이미지 저장
|
||||
composite_path = OUTPUT_DIR / f"frame_{fidx:04d}_composite.jpg"
|
||||
cv2.imwrite(str(composite_path), composite)
|
||||
|
||||
# 원본 프레임 저장 (비교용)
|
||||
original_path = OUTPUT_DIR / f"frame_{fidx:04d}_original.jpg"
|
||||
cv2.imwrite(str(original_path), frame)
|
||||
|
||||
all_results[f"frame_{fidx}"] = {
|
||||
"frame_index": fidx,
|
||||
"detections": {
|
||||
prompt: len(shapes)
|
||||
for prompt, shapes in frame_results.items()
|
||||
},
|
||||
}
|
||||
|
||||
# 4. 결과 요약 저장
|
||||
summary_path = OUTPUT_DIR / "segmentation_summary.json"
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 5. 범례 이미지 생성
|
||||
legend = np.zeros((len(PROMPTS) * 30 + 20, 300, 3), dtype=np.uint8)
|
||||
for i, (prompt, color) in enumerate(COLORS.items()):
|
||||
y = i * 30 + 20
|
||||
cv2.rectangle(legend, (10, y - 12), (30, y + 5), color, -1)
|
||||
cv2.putText(legend, prompt, (40, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
cv2.imwrite(str(OUTPUT_DIR / "legend.jpg"), legend)
|
||||
|
||||
print(f"\n=== 완료 ===")
|
||||
print(f"결과 저장: {OUTPUT_DIR}/")
|
||||
print(f" - frame_XXXX_original.jpg : 원본 프레임")
|
||||
print(f" - frame_XXXX_composite.jpg : 전체 세그멘테이션 합성")
|
||||
print(f" - frame_XXXX_layer_*.png : 클래스별 개별 레이어 (투명 배경)")
|
||||
print(f" - segmentation_summary.json : 검출 요약")
|
||||
print(f" - legend.jpg : 색상 범례")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
618
tools/web_ui.py
Normal file
618
tools/web_ui.py
Normal file
@@ -0,0 +1,618 @@
|
||||
"""
|
||||
Railway Detection Web UI
|
||||
사용법: python tools/web_ui.py
|
||||
브라우저: http://localhost:7000
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
||||
import uvicorn
|
||||
except ImportError:
|
||||
print("FastAPI/uvicorn 설치 필요: pip install fastapi uvicorn")
|
||||
sys.exit(1)
|
||||
|
||||
app = FastAPI()
|
||||
jobs: dict = {} # job_id -> {queue, status, result_path}
|
||||
|
||||
ROOT = Path(__file__).parent.parent # 프로젝트 루트
|
||||
|
||||
|
||||
# ── 미리보기 이미지 생성 ───────────────────────────────────────────────────────
|
||||
def make_preview(image_path: str, cols: int, rows: int, selected_rows: list) -> str:
|
||||
buf = np.fromfile(image_path, dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise ValueError(f"이미지 로드 실패: {image_path}")
|
||||
|
||||
H, W = img.shape[:2]
|
||||
tw = 2000
|
||||
scale = tw / W
|
||||
vis = cv2.resize(img, (tw, int(H * scale)))
|
||||
H_s, W_s = vis.shape[:2]
|
||||
|
||||
tile_w = W_s / cols
|
||||
tile_h = H_s / rows
|
||||
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
idx = r * cols + c + 1
|
||||
x0, y0 = int(c * tile_w), int(r * tile_h)
|
||||
x1, y1 = min(W_s, int((c + 1) * tile_w)), min(H_s, int((r + 1) * tile_h))
|
||||
row_num = r + 1
|
||||
|
||||
if row_num in selected_rows:
|
||||
overlay = vis.copy()
|
||||
cv2.rectangle(overlay, (x0, y0), (x1, y1), (0, 200, 0), -1)
|
||||
cv2.addWeighted(overlay, 0.25, vis, 0.75, 0, vis)
|
||||
cv2.rectangle(vis, (x0, y0), (x1, y1), (0, 255, 0), 2)
|
||||
else:
|
||||
cv2.rectangle(vis, (x0, y0), (x1, y1), (180, 180, 180), 1)
|
||||
|
||||
label = str(idx)
|
||||
fs = max(0.35, tile_w / 120)
|
||||
(tw2, th2), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fs, 1)
|
||||
cx, cy = (x0 + x1) // 2, (y0 + y1) // 2
|
||||
cv2.putText(vis, label, (cx - tw2 // 2, cy + th2 // 2),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, fs, (255, 255, 255), 1, cv2.LINE_AA)
|
||||
|
||||
# Row 번호
|
||||
y0 = int(r * tile_h)
|
||||
row_num = r + 1
|
||||
color = (0, 255, 0) if row_num in selected_rows else (160, 160, 160)
|
||||
cv2.putText(vis, f"Row{row_num}", (4, y0 + 18),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2, cv2.LINE_AA)
|
||||
|
||||
_, enc = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 88])
|
||||
return base64.b64encode(enc).decode("utf-8")
|
||||
|
||||
|
||||
def tiles_from_rows(selected_rows: list, cols: int) -> str:
|
||||
if not selected_rows:
|
||||
return "all"
|
||||
ids = []
|
||||
for r in sorted(selected_rows):
|
||||
ids.extend(range((r - 1) * cols + 1, r * cols + 1))
|
||||
if ids and ids[-1] - ids[0] + 1 == len(ids):
|
||||
return f"{ids[0]}-{ids[-1]}"
|
||||
return ",".join(str(t) for t in ids)
|
||||
|
||||
|
||||
# ── API ───────────────────────────────────────────────────────────────────────
|
||||
@app.post("/api/preview")
|
||||
async def api_preview(data: dict):
|
||||
try:
|
||||
b64 = make_preview(
|
||||
data["image_path"],
|
||||
data.get("cols", 8),
|
||||
data.get("rows", 6),
|
||||
data.get("selected_rows", []),
|
||||
)
|
||||
return {"ok": True, "image": b64}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
|
||||
|
||||
@app.post("/api/detect")
|
||||
async def api_detect(data: dict):
|
||||
job_id = str(uuid.uuid4())[:8]
|
||||
q: queue.Queue = queue.Queue()
|
||||
jobs[job_id] = {"queue": q, "status": "running", "result_path": None, "proc": None}
|
||||
|
||||
def run():
|
||||
tiles_str = tiles_from_rows(data.get("selected_rows", []), data.get("cols", 8))
|
||||
cmd = [
|
||||
sys.executable, "-u", str(ROOT / "tools" / "detect_all_objects.py"),
|
||||
"--input", data["image_path"],
|
||||
"--categories", data.get("categories", "configs/railway_zone.json"),
|
||||
"--tiles", tiles_str,
|
||||
"--cols", str(data.get("cols", 8)),
|
||||
"--rows", str(data.get("rows", 6)),
|
||||
"--overlap", str(data.get("overlap", 0.20)),
|
||||
"--conf", str(data.get("conf", 0.20)),
|
||||
"--workers", str(data.get("workers", 8)),
|
||||
"--save-json"
|
||||
]
|
||||
|
||||
env = {**os.environ, "PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
|
||||
flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
text=True, encoding="utf-8", errors="replace",
|
||||
env=env, cwd=str(ROOT), creationflags=flags,
|
||||
)
|
||||
jobs[job_id]["proc"] = proc
|
||||
for raw in proc.stdout:
|
||||
for line in raw.replace("\r", "\n").split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
q.put(("log", line))
|
||||
if line.startswith("저장:"):
|
||||
jobs[job_id]["result_path"] = line.replace("저장:", "").strip()
|
||||
proc.wait()
|
||||
if jobs[job_id]["status"] == "stopped":
|
||||
q.put(("error", "사용자가 중지했습니다"))
|
||||
elif proc.returncode == 0:
|
||||
q.put(("done", jobs[job_id]["result_path"] or "완료"))
|
||||
jobs[job_id]["status"] = "done"
|
||||
else:
|
||||
q.put(("error", f"종료코드 {proc.returncode}"))
|
||||
jobs[job_id]["status"] = "error"
|
||||
except Exception as e:
|
||||
q.put(("error", str(e)))
|
||||
jobs[job_id]["status"] = "error"
|
||||
|
||||
threading.Thread(target=run, daemon=True).start()
|
||||
return {"job_id": job_id}
|
||||
|
||||
|
||||
@app.post("/api/stop/{job_id}")
|
||||
async def api_stop(job_id: str):
|
||||
job = jobs.get(job_id)
|
||||
if not job:
|
||||
return {"ok": False, "error": "not found"}
|
||||
job["status"] = "stopped"
|
||||
proc = job.get("proc")
|
||||
if proc and proc.poll() is None:
|
||||
if os.name == "nt":
|
||||
subprocess.call(["taskkill", "/F", "/T", "/PID", str(proc.pid)],
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
else:
|
||||
proc.terminate()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/api/progress/{job_id}")
|
||||
async def api_progress(job_id: str):
|
||||
if job_id not in jobs:
|
||||
return JSONResponse({"error": "not found"}, status_code=404)
|
||||
|
||||
async def event_gen():
|
||||
q = jobs[job_id]["queue"]
|
||||
while True:
|
||||
try:
|
||||
type_, msg = q.get_nowait()
|
||||
yield f"data: {json.dumps({'type': type_, 'msg': msg})}\n\n"
|
||||
if type_ in ("done", "error"):
|
||||
break
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.25)
|
||||
yield ": ping\n\n"
|
||||
|
||||
return StreamingResponse(event_gen(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
@app.get("/api/result/{job_id}")
|
||||
async def api_result(job_id: str):
|
||||
job = jobs.get(job_id, {})
|
||||
path = job.get("result_path")
|
||||
if not path or not Path(path).exists():
|
||||
return {"ok": False, "error": "결과 없음"}
|
||||
buf = np.fromfile(path, dtype=np.uint8)
|
||||
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
return {"ok": False, "error": "이미지 로드 실패"}
|
||||
h, w = img.shape[:2]
|
||||
if max(h, w) > 3000:
|
||||
s = 3000 / max(h, w)
|
||||
img = cv2.resize(img, (int(w * s), int(h * s)))
|
||||
_, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
||||
return {"ok": True, "image": base64.b64encode(enc).decode(), "path": path}
|
||||
|
||||
|
||||
@app.post("/api/open-folder")
|
||||
async def api_open_folder(data: dict):
|
||||
path = Path(data.get("path", "output/detect"))
|
||||
folder = path.parent if path.is_file() else path
|
||||
if os.name == "nt":
|
||||
os.startfile(str(folder.resolve()))
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── HTML ──────────────────────────────────────────────────────────────────────
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index():
|
||||
return HTML
|
||||
|
||||
|
||||
HTML = r"""<!DOCTYPE html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Railway Detection UI</title>
|
||||
<style>
|
||||
*{box-sizing:border-box;margin:0;padding:0}
|
||||
body{font-family:'Segoe UI',sans-serif;background:#1a1a2e;color:#e0e0e0;height:100vh;display:flex;flex-direction:column;overflow:hidden}
|
||||
header{background:#16213e;padding:10px 20px;border-bottom:2px solid #0f3460;flex-shrink:0}
|
||||
header h1{font-size:1.1rem;color:#e94560;letter-spacing:1px}
|
||||
.tabs{display:flex;background:#16213e;border-bottom:1px solid #0f3460;flex-shrink:0}
|
||||
.tab-btn{padding:9px 26px;background:none;border:none;color:#888;cursor:pointer;font-size:0.92rem;border-bottom:3px solid transparent;transition:all .2s}
|
||||
.tab-btn.active{color:#e94560;border-bottom-color:#e94560}
|
||||
.tab-content{display:none;flex:1;overflow:hidden;min-height:0}
|
||||
.tab-content.active{display:flex}
|
||||
|
||||
/* 입력 탭 */
|
||||
#tab-input{flex-direction:row}
|
||||
.controls{width:320px;min-width:260px;background:#16213e;padding:14px;overflow-y:auto;border-right:1px solid #0f3460;display:flex;flex-direction:column;gap:12px;flex-shrink:0}
|
||||
.preview-area{flex:1;background:#0d0d1f;overflow:hidden;position:relative;cursor:grab;min-height:0}
|
||||
.preview-area.drag{cursor:grabbing}
|
||||
.preview-area img{position:absolute;transform-origin:0 0;user-select:none;pointer-events:none}
|
||||
.preview-area .hint{position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);color:#444;font-size:.9rem}
|
||||
.preview-hud{position:absolute;top:8px;right:10px;font-size:.74rem;color:#555;background:rgba(0,0,0,.35);padding:3px 8px;border-radius:4px;pointer-events:none;z-index:2}
|
||||
.preview-fit{position:absolute;top:8px;left:10px;font-size:.78rem;background:#0f3460;color:#e0e0e0;border:1px solid #2a4a7a;padding:4px 10px;border-radius:4px;cursor:pointer;z-index:2}
|
||||
.preview-fit:hover{background:#1a5299}
|
||||
.fg{display:flex;flex-direction:column;gap:4px}
|
||||
.fg>label{font-size:.75rem;color:#888;font-weight:600;text-transform:uppercase;letter-spacing:.5px}
|
||||
.path-row{display:flex;gap:6px}
|
||||
.path-row input{flex:1;min-width:0}
|
||||
input[type=text],input[type=number]{background:#0f3460;border:1px solid #2a4a7a;color:#e0e0e0;padding:6px 9px;border-radius:4px;font-size:.88rem;width:100%}
|
||||
input[type=text]:focus,input[type=number]:focus{outline:none;border-color:#e94560}
|
||||
.row-grid{display:grid;grid-template-columns:repeat(3,1fr);gap:6px}
|
||||
.row-lbl{display:flex;align-items:center;gap:6px;padding:7px 9px;background:#0f3460;border:1px solid #2a4a7a;border-radius:4px;cursor:pointer;transition:all .15s;font-size:.85rem}
|
||||
.row-lbl:hover{border-color:#0c6}
|
||||
.row-lbl.on{border-color:#0f6;background:#0d3320;color:#0f6}
|
||||
.row-lbl input{width:14px;height:14px;cursor:pointer;accent-color:#0f6}
|
||||
.params-grid{display:grid;grid-template-columns:1fr 1fr;gap:8px}
|
||||
.pi{display:flex;flex-direction:column;gap:3px}
|
||||
.pi label{font-size:.72rem;color:#888}
|
||||
.btn-load{background:#0f3460;color:#e0e0e0;border:1px solid #2a4a7a;padding:6px 13px;border-radius:4px;cursor:pointer;font-size:.84rem;white-space:nowrap;transition:background .2s}
|
||||
.btn-load:hover{background:#1a5299}
|
||||
.btn-run{background:#e94560;color:#fff;border:none;padding:11px;border-radius:6px;cursor:pointer;font-size:.98rem;font-weight:700;width:100%;margin-top:2px;transition:background .2s}
|
||||
.btn-run:hover{background:#c73652}
|
||||
.btn-run:disabled{background:#555;cursor:not-allowed}
|
||||
.btn-stop{background:#555;color:#fff;border:none;padding:11px;border-radius:6px;cursor:pointer;font-size:.98rem;font-weight:700;width:100%;margin-top:2px;display:none}
|
||||
.btn-stop.visible{display:block}
|
||||
.btn-stop:hover{background:#c73652}
|
||||
|
||||
/* 결과 탭 */
|
||||
#tab-result{flex-direction:column}
|
||||
.res-toolbar{background:#16213e;padding:9px 14px;display:flex;align-items:center;gap:10px;border-bottom:1px solid #0f3460;flex-shrink:0}
|
||||
.btn-act{background:#0f3460;color:#e0e0e0;border:1px solid #2a4a7a;padding:6px 13px;border-radius:4px;cursor:pointer;font-size:.84rem;transition:background .2s}
|
||||
.btn-act:hover{background:#1a5299}
|
||||
#result-path{font-size:.78rem;color:#888;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;flex:1;min-width:0}
|
||||
.viewer{flex:1;overflow:hidden;cursor:grab;position:relative;background:#0a0a1a;min-height:0}
|
||||
.viewer.drag{cursor:grabbing}
|
||||
.viewer img{position:absolute;transform-origin:0 0;user-select:none;pointer-events:none}
|
||||
.viewer-hint{display:flex;align-items:center;justify-content:center;height:100%;color:#333;font-size:.95rem}
|
||||
|
||||
/* 진행 */
|
||||
.prog-section{background:#16213e;padding:8px 16px;border-top:1px solid #0f3460;flex-shrink:0}
|
||||
.prog-track{background:#0a0a1a;border-radius:4px;height:7px;overflow:hidden;margin-bottom:5px}
|
||||
.prog-fill{height:100%;background:linear-gradient(90deg,#e94560,#0f6);width:0%;transition:width .4s;border-radius:4px}
|
||||
.prog-log{font-size:.76rem;color:#777;white-space:nowrap;overflow:hidden;text-overflow:ellipsis}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header><h1>🛤 Railway Detection UI</h1></header>
|
||||
|
||||
<div class="tabs">
|
||||
<button class="tab-btn active" id="tbtn-input" onclick="switchTab('input')">입력</button>
|
||||
<button class="tab-btn" id="tbtn-result" onclick="switchTab('result')">결과</button>
|
||||
</div>
|
||||
|
||||
<div id="tab-input" class="tab-content active">
|
||||
<div class="controls">
|
||||
|
||||
<div class="fg">
|
||||
<label>이미지 경로</label>
|
||||
<div class="path-row">
|
||||
<input type="text" id="image-path" placeholder="경로 입력 후 Enter 또는 버튼">
|
||||
<button class="btn-load" onclick="loadPreview()">미리보기</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="fg">
|
||||
<label>Row 선택 (철도 구역)</label>
|
||||
<div class="row-grid" id="row-grid"></div>
|
||||
</div>
|
||||
|
||||
<div class="fg">
|
||||
<label>Categories JSON</label>
|
||||
<input type="text" id="categories" value="configs/railway_zone.json">
|
||||
</div>
|
||||
|
||||
<div class="fg">
|
||||
<label>파라미터</label>
|
||||
<div class="params-grid">
|
||||
<div class="pi"><label>Cols</label><input type="number" id="cols" value="8" min="1" onchange="loadPreview()"></div>
|
||||
<div class="pi"><label>Rows</label><input type="number" id="rows" value="6" min="1" onchange="initRows();loadPreview()"></div>
|
||||
<div class="pi"><label>Overlap</label><input type="number" id="overlap" value="0.20" step="0.05" min="0" max="0.5"></div>
|
||||
<div class="pi"><label>Conf</label><input type="number" id="conf" value="0.20" step="0.05" min="0.05" max="1"></div>
|
||||
<div class="pi"><label>Workers</label><input type="number" id="workers" value="8" min="1" max="32"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button class="btn-run" id="btn-run" onclick="startDetect()">▶ 검출 시작</button>
|
||||
<button class="btn-stop" id="btn-stop" onclick="stopDetect()">■ 중지</button>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="preview-area" id="preview-area">
|
||||
<button class="preview-fit" onclick="resetPreviewZoom()">화면 맞춤</button>
|
||||
<span class="preview-hud">휠: 확대/축소 | 드래그: 이동</span>
|
||||
<span class="hint" id="preview-hint">이미지 경로 입력 후 미리보기 클릭</span>
|
||||
<img id="preview-img" style="display:none">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="tab-result" class="tab-content">
|
||||
<div class="res-toolbar">
|
||||
<button class="btn-act" onclick="openFolder()">📁 출력 폴더 열기</button>
|
||||
<span id="result-path">결과 없음</span>
|
||||
<button class="btn-act" onclick="resetZoom()" style="margin-left:auto">화면 맞춤</button>
|
||||
<span style="font-size:.76rem;color:#555">휠: 확대/축소 | 드래그: 이동</span>
|
||||
</div>
|
||||
<div class="viewer" id="viewer">
|
||||
<div class="viewer-hint" id="viewer-hint">검출 완료 후 결과가 표시됩니다</div>
|
||||
<img id="result-img" style="display:none">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="prog-section">
|
||||
<div class="prog-track"><div class="prog-fill" id="prog-fill"></div></div>
|
||||
<div class="prog-log" id="prog-log">대기 중</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let resultPath = null;
|
||||
let sc = 1, tx = 0, ty = 0, dragging = false, sx, sy;
|
||||
|
||||
// ── 탭 전환
|
||||
function switchTab(name) {
|
||||
document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
|
||||
document.querySelectorAll('.tab-btn').forEach(el => el.classList.remove('active'));
|
||||
document.getElementById('tab-' + name).classList.add('active');
|
||||
document.getElementById('tbtn-' + name).classList.add('active');
|
||||
}
|
||||
|
||||
// ── Row 체크박스
|
||||
function initRows() {
|
||||
const numRows = parseInt(document.getElementById('rows').value) || 6;
|
||||
const grid = document.getElementById('row-grid');
|
||||
grid.innerHTML = '';
|
||||
for (let r = 1; r <= numRows; r++) {
|
||||
const lbl = document.createElement('label');
|
||||
lbl.className = 'row-lbl';
|
||||
lbl.innerHTML = `<input type="checkbox" value="${r}" onchange="onRowChange(this)"><span>Row ${r}</span>`;
|
||||
grid.appendChild(lbl);
|
||||
}
|
||||
}
|
||||
|
||||
function onRowChange(cb) {
|
||||
cb.parentElement.classList.toggle('on', cb.checked);
|
||||
loadPreview();
|
||||
}
|
||||
|
||||
function getSelectedRows() {
|
||||
return Array.from(document.querySelectorAll('#row-grid input:checked')).map(c => parseInt(c.value));
|
||||
}
|
||||
|
||||
// ── 미리보기
|
||||
async function loadPreview() {
|
||||
const imgPath = document.getElementById('image-path').value.trim();
|
||||
if (!imgPath) return;
|
||||
const hint = document.getElementById('preview-hint');
|
||||
const pimg = document.getElementById('preview-img');
|
||||
hint.style.display = 'block';
|
||||
hint.style.color = '';
|
||||
hint.textContent = '로딩 중...';
|
||||
try {
|
||||
const res = await post('/api/preview', {
|
||||
image_path: imgPath,
|
||||
cols: +document.getElementById('cols').value,
|
||||
rows: +document.getElementById('rows').value,
|
||||
selected_rows: getSelectedRows()
|
||||
});
|
||||
if (res.ok) {
|
||||
const keepZoom = pimg.style.display === 'block';
|
||||
pimg.onload = () => { if (!keepZoom) resetPreviewZoom(); else applyPreviewT(); };
|
||||
pimg.src = 'data:image/jpeg;base64,' + res.image;
|
||||
pimg.style.display = 'block';
|
||||
hint.style.display = 'none';
|
||||
} else {
|
||||
hint.style.color = '#e94560';
|
||||
hint.textContent = '오류: ' + res.error;
|
||||
}
|
||||
} catch(e) {
|
||||
hint.style.color = '#e94560';
|
||||
hint.textContent = String(e);
|
||||
}
|
||||
}
|
||||
|
||||
// ── 검출 시작
|
||||
async function startDetect() {
|
||||
const rows = getSelectedRows();
|
||||
if (!rows.length) { alert('Row를 하나 이상 선택하세요'); return; }
|
||||
const imgPath = document.getElementById('image-path').value.trim();
|
||||
if (!imgPath) { alert('이미지 경로를 입력하세요'); return; }
|
||||
|
||||
document.getElementById('btn-run').disabled = true;
|
||||
document.getElementById('btn-stop').classList.add('visible');
|
||||
document.getElementById('prog-fill').style.width = '0%';
|
||||
document.getElementById('prog-log').textContent = '시작 중...';
|
||||
|
||||
const res = await post('/api/detect', {
|
||||
image_path: imgPath,
|
||||
selected_rows: rows,
|
||||
cols: +document.getElementById('cols').value,
|
||||
rows: +document.getElementById('rows').value,
|
||||
overlap: +document.getElementById('overlap').value,
|
||||
conf: +document.getElementById('conf').value,
|
||||
workers: +document.getElementById('workers').value,
|
||||
categories: document.getElementById('categories').value.trim()
|
||||
});
|
||||
listenProgress(res.job_id);
|
||||
}
|
||||
|
||||
// ── SSE 진행
|
||||
function listenProgress(jobId) {
|
||||
const es = new EventSource(`/api/progress/${jobId}`);
|
||||
es.onmessage = async e => {
|
||||
const {type, msg} = JSON.parse(e.data);
|
||||
document.getElementById('prog-log').textContent = msg || '';
|
||||
if (type === 'log') {
|
||||
const m = msg.match(/타일\s+(\d+)\/(\d+)/);
|
||||
if (m) {
|
||||
document.getElementById('prog-fill').style.width =
|
||||
Math.round(+m[1] / +m[2] * 100) + '%';
|
||||
}
|
||||
} else if (type === 'done') {
|
||||
es.close();
|
||||
document.getElementById('prog-fill').style.width = '100%';
|
||||
document.getElementById('prog-log').textContent = '완료! 결과 탭으로 이동합니다.';
|
||||
setRunning(false);
|
||||
await showResult(jobId);
|
||||
} else if (type === 'error') {
|
||||
es.close();
|
||||
document.getElementById('prog-log').textContent = msg;
|
||||
setRunning(false);
|
||||
}
|
||||
};
|
||||
es.onerror = () => { es.close(); setRunning(false); };
|
||||
}
|
||||
|
||||
// ── 결과 표시
|
||||
async function showResult(jobId) {
|
||||
const res = await fetch(`/api/result/${jobId}`).then(r => r.json());
|
||||
if (!res.ok) return;
|
||||
resultPath = res.path;
|
||||
document.getElementById('result-path').textContent = res.path;
|
||||
const img = document.getElementById('result-img');
|
||||
img.src = 'data:image/jpeg;base64,' + res.image;
|
||||
img.style.display = 'block';
|
||||
document.getElementById('viewer-hint').style.display = 'none';
|
||||
img.onload = resetZoom;
|
||||
switchTab('result');
|
||||
}
|
||||
|
||||
// ── Preview Zoom / Pan
|
||||
const previewArea = document.getElementById('preview-area');
|
||||
const pimg = document.getElementById('preview-img');
|
||||
let psc = 1, ptx = 0, pty = 0, pdragging = false, psx, psy;
|
||||
|
||||
previewArea.addEventListener('wheel', e => {
|
||||
if (pimg.style.display !== 'block') return;
|
||||
e.preventDefault();
|
||||
const rect = previewArea.getBoundingClientRect();
|
||||
const mx = e.clientX - rect.left, my = e.clientY - rect.top;
|
||||
const f = e.deltaY < 0 ? 1.15 : 0.87;
|
||||
const ns = Math.min(Math.max(0.05, psc * f), 40);
|
||||
ptx = mx - (mx - ptx) * (ns / psc);
|
||||
pty = my - (my - pty) * (ns / psc);
|
||||
psc = ns;
|
||||
applyPreviewT();
|
||||
}, {passive: false});
|
||||
|
||||
previewArea.addEventListener('mousedown', e => {
|
||||
if (pimg.style.display !== 'block') return;
|
||||
if (e.target.classList.contains('preview-fit')) return;
|
||||
pdragging = true; psx = e.clientX - ptx; psy = e.clientY - pty;
|
||||
previewArea.classList.add('drag');
|
||||
});
|
||||
document.addEventListener('mousemove', e => {
|
||||
if (!pdragging) return;
|
||||
ptx = e.clientX - psx; pty = e.clientY - psy; applyPreviewT();
|
||||
});
|
||||
document.addEventListener('mouseup', () => { pdragging = false; previewArea.classList.remove('drag'); });
|
||||
|
||||
function applyPreviewT() { pimg.style.transform = `translate(${ptx}px,${pty}px) scale(${psc})`; }
|
||||
|
||||
function resetPreviewZoom() {
|
||||
const vw = previewArea.clientWidth, vh = previewArea.clientHeight;
|
||||
if (!pimg.naturalWidth) return;
|
||||
psc = Math.min(vw / pimg.naturalWidth, vh / pimg.naturalHeight) * 0.98;
|
||||
ptx = (vw - pimg.naturalWidth * psc) / 2;
|
||||
pty = (vh - pimg.naturalHeight * psc) / 2;
|
||||
applyPreviewT();
|
||||
}
|
||||
|
||||
// ── Result Zoom / Pan
|
||||
const viewer = document.getElementById('viewer');
|
||||
const rimg = document.getElementById('result-img');
|
||||
|
||||
viewer.addEventListener('wheel', e => {
|
||||
e.preventDefault();
|
||||
const rect = viewer.getBoundingClientRect();
|
||||
const mx = e.clientX - rect.left, my = e.clientY - rect.top;
|
||||
const f = e.deltaY < 0 ? 1.15 : 0.87;
|
||||
const ns = Math.min(Math.max(0.05, sc * f), 40);
|
||||
tx = mx - (mx - tx) * (ns / sc);
|
||||
ty = my - (my - ty) * (ns / sc);
|
||||
sc = ns;
|
||||
applyT();
|
||||
}, {passive: false});
|
||||
|
||||
viewer.addEventListener('mousedown', e => {
|
||||
dragging = true; sx = e.clientX - tx; sy = e.clientY - ty;
|
||||
viewer.classList.add('drag');
|
||||
});
|
||||
document.addEventListener('mousemove', e => {
|
||||
if (!dragging) return;
|
||||
tx = e.clientX - sx; ty = e.clientY - sy; applyT();
|
||||
});
|
||||
document.addEventListener('mouseup', () => { dragging = false; viewer.classList.remove('drag'); });
|
||||
|
||||
function applyT() { rimg.style.transform = `translate(${tx}px,${ty}px) scale(${sc})`; }
|
||||
|
||||
function resetZoom() {
|
||||
const vw = viewer.clientWidth, vh = viewer.clientHeight;
|
||||
if (!rimg.naturalWidth) return;
|
||||
sc = Math.min(vw / rimg.naturalWidth, vh / rimg.naturalHeight) * 0.95;
|
||||
tx = (vw - rimg.naturalWidth * sc) / 2;
|
||||
ty = (vh - rimg.naturalHeight * sc) / 2;
|
||||
applyT();
|
||||
}
|
||||
|
||||
// ── 중지
|
||||
let currentES = null;
|
||||
async function stopDetect() {
|
||||
if (!currentJobId) return;
|
||||
await post('/api/stop/' + currentJobId, {});
|
||||
document.getElementById('prog-log').textContent = '중지 요청됨...';
|
||||
}
|
||||
|
||||
function setRunning(on) {
|
||||
document.getElementById('btn-run').disabled = on;
|
||||
document.getElementById('btn-stop').classList.toggle('visible', on);
|
||||
}
|
||||
|
||||
// ── 폴더 열기
|
||||
async function openFolder() {
|
||||
await post('/api/open-folder', {path: resultPath || 'output/detect'});
|
||||
}
|
||||
|
||||
// ── 유틸
|
||||
async function post(url, data) {
|
||||
const r = await fetch(url, {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify(data)});
|
||||
return r.json();
|
||||
}
|
||||
|
||||
// ── 초기화
|
||||
initRows();
|
||||
document.getElementById('image-path').addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') loadPreview();
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Railway Detection UI 시작: http://localhost:7000")
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000, log_level="warning")
|
||||
333
tools/yoloworld_sam3_pipeline.py
Normal file
333
tools/yoloworld_sam3_pipeline.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
YOLO-World + SAM3 반자동 레이블링 파이프라인
|
||||
=============================================
|
||||
YOLO-World로 bbox 검출 → SAM3로 polygon mask 생성 → 시각화 저장
|
||||
|
||||
사용법:
|
||||
python tools/yoloworld_sam3_pipeline.py --input sample/rail --output output/labeled
|
||||
python tools/yoloworld_sam3_pipeline.py --input sample/rail/frame_00000.jpg
|
||||
|
||||
SAM3 서버가 실행 중이어야 합니다:
|
||||
cd X-AnyLabeling-Server && uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
# ── 검출 대상 클래스 (YOLO-World 텍스트 프롬프트) ──────────────────────────
|
||||
TARGET_CLASSES = [
|
||||
"catenary pole", # 전철주 (세로 기둥)
|
||||
"catenary arm", # 전철주 (가로 암)
|
||||
"junction box", # 통신/전기 박스
|
||||
"fence", # 펜스
|
||||
]
|
||||
|
||||
# 클래스별 색상 (BGR)
|
||||
CLASS_COLORS = {
|
||||
"catenary pole": (0, 200, 255), # 주황
|
||||
"catenary arm": (0, 100, 255), # 빨강
|
||||
"junction box": (255, 180, 0), # 파랑
|
||||
"fence": (0, 255, 100), # 초록
|
||||
}
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
MODEL_ID = "segment_anything_3"
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# YOLO-World 초기화
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
def load_yolo_world(model_size: str = "s"):
|
||||
"""YOLO-World 모델 로드 (자동 다운로드)."""
|
||||
from ultralytics import YOLOWorld
|
||||
|
||||
model_name = f"yolov8{model_size}-worldv2.pt"
|
||||
print(f"[YOLO-World] 모델 로드: {model_name}")
|
||||
model = YOLOWorld(model_name)
|
||||
model.set_classes(TARGET_CLASSES)
|
||||
print(f"[YOLO-World] 검출 클래스: {TARGET_CLASSES}")
|
||||
return model
|
||||
|
||||
|
||||
def detect_with_yoloworld(model, image_bgr: np.ndarray, conf: float = 0.15):
|
||||
"""YOLO-World로 bbox 검출. [(x1,y1,x2,y2,conf,class_name), ...] 반환."""
|
||||
results = model.predict(image_bgr, conf=conf, verbose=False)
|
||||
detections = []
|
||||
if results and len(results) > 0:
|
||||
r = results[0]
|
||||
boxes = r.boxes
|
||||
if boxes is not None and len(boxes) > 0:
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
score = float(box.conf[0].cpu())
|
||||
cls_idx = int(box.cls[0].cpu())
|
||||
cls_name = TARGET_CLASSES[cls_idx] if cls_idx < len(TARGET_CLASSES) else f"class_{cls_idx}"
|
||||
detections.append((float(x1), float(y1), float(x2), float(y2), score, cls_name))
|
||||
return detections
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# SAM3 서버 호출
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
def encode_image(image_bgr: np.ndarray) -> str:
|
||||
"""이미지를 base64 문자열로 인코딩."""
|
||||
_, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
return base64.b64encode(buf).decode("utf-8")
|
||||
|
||||
|
||||
def sam3_segment(image_bgr: np.ndarray, boxes: list, conf_threshold: float = 0.25):
|
||||
"""SAM3 서버에 bbox 전달 → polygon masks 반환.
|
||||
|
||||
Args:
|
||||
image_bgr: 원본 이미지
|
||||
boxes: [(x1,y1,x2,y2,score,class_name), ...]
|
||||
conf_threshold: SAM3 신뢰도 임계값
|
||||
|
||||
Returns:
|
||||
shapes: [{"label": str, "points": [[x,y],...], "score": float}, ...]
|
||||
"""
|
||||
marks = [
|
||||
{
|
||||
"type": "rectangle",
|
||||
"label": 1,
|
||||
"data": [b[0], b[1], b[2], b[3]],
|
||||
}
|
||||
for b in boxes
|
||||
]
|
||||
|
||||
payload = {
|
||||
"model": MODEL_ID,
|
||||
"image": encode_image(image_bgr),
|
||||
"params": {
|
||||
"marks": marks,
|
||||
"show_masks": True,
|
||||
"show_boxes": False,
|
||||
"conf_threshold": conf_threshold,
|
||||
"epsilon_factor": 0.002,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
print(f" [ERROR] SAM3 서버에 연결할 수 없습니다: {SAM3_SERVER}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f" [ERROR] SAM3 호출 실패: {e}")
|
||||
return []
|
||||
|
||||
if data.get("status") != "success":
|
||||
print(f" [ERROR] SAM3 응답 오류: {data}")
|
||||
return []
|
||||
|
||||
return data.get("data", {}).get("shapes", [])
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 시각화
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
def draw_results(image_bgr: np.ndarray, detections: list, shapes: list) -> np.ndarray:
|
||||
"""bbox + mask를 이미지에 그리기."""
|
||||
vis = image_bgr.copy()
|
||||
overlay = image_bgr.copy()
|
||||
|
||||
# SAM3 polygon masks 그리기
|
||||
for shape in shapes:
|
||||
pts = np.array(shape["points"], dtype=np.int32)
|
||||
# 첫점=끝점이면 마지막 제거
|
||||
if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
|
||||
pts = pts[:-1]
|
||||
label = shape.get("label", "unknown")
|
||||
color = CLASS_COLORS.get(label, (200, 200, 200))
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.polylines(vis, [pts], True, color, 2)
|
||||
|
||||
cv2.addWeighted(overlay, 0.35, vis, 0.65, 0, vis)
|
||||
|
||||
# YOLO-World bbox + 라벨 그리기
|
||||
for x1, y1, x2, y2, score, cls_name in detections:
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
color = CLASS_COLORS.get(cls_name, (200, 200, 200))
|
||||
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
|
||||
label_text = f"{cls_name} {score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
|
||||
cv2.rectangle(vis, (x1, y1 - th - 6), (x1 + tw + 4, y1), color, -1)
|
||||
cv2.putText(vis, label_text, (x1 + 2, y1 - 4),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
|
||||
|
||||
# 범례
|
||||
legend_y = 20
|
||||
for cls_name, color in CLASS_COLORS.items():
|
||||
cv2.rectangle(vis, (10, legend_y), (25, legend_y + 15), color, -1)
|
||||
cv2.putText(vis, cls_name, (30, legend_y + 13),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
legend_y += 22
|
||||
|
||||
return vis
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 단일 이미지 처리
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
def process_image(yolo_model, image_path: Path, output_dir: Path,
|
||||
yolo_conf: float, sam3_conf: float, skip_sam3: bool) -> dict:
|
||||
"""이미지 1장 처리. 결과 딕셔너리 반환."""
|
||||
image_bgr = cv2.imread(str(image_path))
|
||||
if image_bgr is None:
|
||||
print(f" [SKIP] 이미지 읽기 실패: {image_path}")
|
||||
return {}
|
||||
|
||||
h, w = image_bgr.shape[:2]
|
||||
print(f"\n 이미지: {image_path.name} ({w}x{h})")
|
||||
|
||||
# ── Step 1: YOLO-World 검출 ────────────────────────────────────────────
|
||||
detections = detect_with_yoloworld(yolo_model, image_bgr, conf=yolo_conf)
|
||||
print(f" YOLO-World: {len(detections)}개 검출")
|
||||
for d in detections:
|
||||
print(f" [{d[5]}] conf={d[4]:.3f} bbox=({d[0]:.0f},{d[1]:.0f},{d[2]:.0f},{d[3]:.0f})")
|
||||
|
||||
shapes = []
|
||||
if not skip_sam3 and detections:
|
||||
# ── Step 2: SAM3 mask 생성 ─────────────────────────────────────────
|
||||
print(f" SAM3: {len(detections)}개 bbox → mask 요청 중...")
|
||||
shapes = sam3_segment(image_bgr, detections, conf_threshold=sam3_conf)
|
||||
print(f" SAM3: {len(shapes)}개 mask 반환")
|
||||
for s in shapes:
|
||||
pts_count = len(s.get("points", []))
|
||||
print(f" [{s.get('label')}] score={s.get('score', 0):.3f} points={pts_count}")
|
||||
|
||||
# ── Step 3: 시각화 저장 ───────────────────────────────────────────────
|
||||
vis = draw_results(image_bgr, detections, shapes)
|
||||
output_path = output_dir / f"{image_path.stem}_result.jpg"
|
||||
cv2.imwrite(str(output_path), vis)
|
||||
print(f" 저장: {output_path}")
|
||||
|
||||
return {
|
||||
"image": image_path.name,
|
||||
"detections": [
|
||||
{"class": d[5], "conf": round(d[4], 3),
|
||||
"bbox": [round(d[0]), round(d[1]), round(d[2]), round(d[3])]}
|
||||
for d in detections
|
||||
],
|
||||
"masks": len(shapes),
|
||||
}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 메인
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="YOLO-World + SAM3 파이프라인")
|
||||
parser.add_argument("--input", default="sample/rail",
|
||||
help="이미지 파일 또는 폴더 경로")
|
||||
parser.add_argument("--output", default="output/yoloworld_sam3",
|
||||
help="결과 저장 폴더")
|
||||
parser.add_argument("--model-size", default="s", choices=["s", "m", "l", "x"],
|
||||
help="YOLO-World 모델 크기 (s/m/l/x)")
|
||||
parser.add_argument("--yolo-conf", type=float, default=0.10,
|
||||
help="YOLO-World 검출 임계값 (기본 0.10)")
|
||||
parser.add_argument("--sam3-conf", type=float, default=0.20,
|
||||
help="SAM3 마스크 임계값 (기본 0.20)")
|
||||
parser.add_argument("--skip-sam3", action="store_true",
|
||||
help="SAM3 건너뛰고 YOLO-World bbox만 시각화")
|
||||
parser.add_argument("--server", default="http://localhost:8000",
|
||||
help="SAM3 서버 주소")
|
||||
args = parser.parse_args()
|
||||
|
||||
global SAM3_SERVER
|
||||
SAM3_SERVER = args.server
|
||||
|
||||
# SAM3 서버 상태 확인
|
||||
if not args.skip_sam3:
|
||||
try:
|
||||
resp = requests.get(f"{SAM3_SERVER}/health", timeout=5)
|
||||
if resp.status_code == 200:
|
||||
print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
|
||||
else:
|
||||
print(f"[WARN] SAM3 서버 응답 이상 (status={resp.status_code})")
|
||||
except Exception:
|
||||
print(f"[WARN] SAM3 서버 연결 실패 ({SAM3_SERVER}). --skip-sam3 로 bbox만 볼 수 있음.")
|
||||
ans = input("계속 진행하시겠습니까? (y/N): ").strip().lower()
|
||||
if ans != "y":
|
||||
sys.exit(0)
|
||||
|
||||
# 입력 경로 처리
|
||||
input_path = Path(args.input)
|
||||
if input_path.is_file():
|
||||
image_files = [input_path]
|
||||
elif input_path.is_dir():
|
||||
image_files = sorted(
|
||||
list(input_path.glob("*.jpg")) +
|
||||
list(input_path.glob("*.jpeg")) +
|
||||
list(input_path.glob("*.png"))
|
||||
)
|
||||
else:
|
||||
print(f"[ERROR] 입력 경로를 찾을 수 없습니다: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not image_files:
|
||||
print(f"[ERROR] 이미지 파일 없음: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\n총 {len(image_files)}개 이미지 처리 예정")
|
||||
|
||||
# 출력 폴더 생성
|
||||
output_dir = Path(args.output)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# YOLO-World 로드
|
||||
yolo_model = load_yolo_world(args.model_size)
|
||||
|
||||
# 처리
|
||||
summary = []
|
||||
for img_path in image_files:
|
||||
result = process_image(
|
||||
yolo_model, img_path, output_dir,
|
||||
yolo_conf=args.yolo_conf,
|
||||
sam3_conf=args.sam3_conf,
|
||||
skip_sam3=args.skip_sam3,
|
||||
)
|
||||
if result:
|
||||
summary.append(result)
|
||||
|
||||
# 요약 저장
|
||||
summary_path = output_dir / "summary.json"
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump(summary, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 통계 출력
|
||||
print("\n" + "="*50)
|
||||
print("처리 완료 요약")
|
||||
print("="*50)
|
||||
total_det = sum(len(r["detections"]) for r in summary)
|
||||
total_mask = sum(r["masks"] for r in summary)
|
||||
print(f"처리 이미지: {len(summary)}장")
|
||||
print(f"YOLO 검출: {total_det}개 (평균 {total_det/max(len(summary),1):.1f}/장)")
|
||||
print(f"SAM3 마스크: {total_mask}개 (평균 {total_mask/max(len(summary),1):.1f}/장)")
|
||||
|
||||
# 클래스별 집계
|
||||
class_counts: dict = {}
|
||||
for r in summary:
|
||||
for d in r["detections"]:
|
||||
cls = d["class"]
|
||||
class_counts[cls] = class_counts.get(cls, 0) + 1
|
||||
if class_counts:
|
||||
print("\n클래스별 검출 수:")
|
||||
for cls, cnt in sorted(class_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {cls:20s}: {cnt}개")
|
||||
|
||||
print(f"\n결과 저장: {output_dir.resolve()}")
|
||||
print(f"요약 JSON: {summary_path.resolve()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user