536 lines
23 KiB
Python
536 lines
23 KiB
Python
"""SAM 원본 폴리곤의 skeleton을 까만색으로 이미지에 겹쳐 표시.
|
|
union mask 기준 skeleton을 그리므로, 겹쳐있는 fragment들의 합쳐진 골격이 보임.
|
|
|
|
사용:
|
|
python tools/render_skeleton_overlay.py \
|
|
--image data/역사이미지/slope/DJI_20260306113842_0006.JPG \
|
|
--label output/autolabel/raw_0006/labels/DJI_20260306113842_0006.txt \
|
|
--output output/autolabel/raw_0006/vis_skeleton.jpg
|
|
"""
|
|
import argparse
|
|
import base64
|
|
import requests
|
|
import cv2
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
# SAM3 inference server, used only by the optional sub-detection path.
SAM3_SERVER = "http://localhost:8000"
SAM3_MODEL_ID = "segment_anything_3"

# Sub-part categories for SAM3 sub-detection. Every entry has:
#   name / name_kr : English identifier / Korean display name
#   prompt         : text prompt fragment (all fragments are joined into one prompt)
#   keywords       : lowercase substrings used to map a SAM3 label back to the category
#   color_bgr      : OpenCV BGR drawing colour
SUB_CATEGORIES = [
    {
        "name": "pole_shaft",
        "name_kr": "전주기둥",
        "prompt": "vertical steel pole shaft, cylindrical mast, tubular pole body",
        "keywords": ["pole", "shaft", "mast", "tubular"],
        "color_bgr": [0, 200, 255],
    },
    {
        "name": "bracket_arm",
        "name_kr": "브라켓/암",
        "prompt": "horizontal cantilever arm, bracket arm, cross arm, pole cross arm",
        "keywords": ["bracket", "cantilever", "cross arm", "arm"],
        "color_bgr": [255, 100, 0],
    },
    {
        "name": "insulator",
        "name_kr": "애자",
        "prompt": "ceramic insulator, glass insulator, suspension insulator, string of insulators",
        "keywords": ["insulator"],
        "color_bgr": [0, 255, 255],
    },
    {
        "name": "wire",
        "name_kr": "가선/조가선",
        "prompt": "overhead contact wire, catenary wire, messenger wire, electric cable",
        "keywords": ["wire", "cable", "catenary"],
        "color_bgr": [200, 200, 255],
    },
]
|
|
|
|
|
|
def _sam3_call(crop_bgr: np.ndarray, prompt: str, conf: float = 0.15,
               max_side: int = 1280) -> list:
    """Run text-prompted SAM3 segmentation on an image crop via the HTTP server.

    The crop is downscaled so its longer side is at most ``max_side`` px,
    JPEG-encoded (quality 90), base64-encoded and POSTed to
    ``{SAM3_SERVER}/v1/predict``. Polygon points in the response are mapped
    back to the crop's original pixel coordinates.

    Args:
        crop_bgr: image crop as produced by OpenCV (assumed H x W x 3 BGR uint8).
        prompt: text forwarded as ``params.text_prompt``.
        conf: value forwarded as ``params.conf_threshold``.
        max_side: longest side (px) sent to the server; larger crops are resized.

    Returns:
        The response shapes whose ``shape_type`` is ``"polygon"``. This is
        best-effort: an empty/None crop, a JPEG encode failure, an
        unsuccessful response, or any request/parse error yields ``[]``
        (errors are printed, not raised).
    """
    # An empty crop would make cv2.imencode raise outside the try below.
    if crop_bgr is None or crop_bgr.size == 0:
        return []

    h, w = crop_bgr.shape[:2]
    resized = False
    sx = sy = 1.0
    if max(h, w) > max_side:
        scale = max_side / max(h, w)
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        crop_bgr = cv2.resize(crop_bgr, (new_w, new_h))
        # Exact per-axis ratios: int() truncation makes 1/scale slightly off.
        sx, sy = w / new_w, h / new_h
        resized = True

    ok, buf = cv2.imencode(".jpg", crop_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        print(" SAM3 오류: JPEG encode failed")
        return []
    b64 = base64.b64encode(buf).decode()
    payload = {
        "model": SAM3_MODEL_ID,
        "image": b64,
        "params": {"text_prompt": prompt, "conf_threshold": conf,
                   "show_masks": True, "show_boxes": False},
    }
    try:
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
        r.raise_for_status()
        resp = r.json()
        if not resp.get("success"):
            return []
        shapes = resp.get("data", {}).get("shapes", [])
        shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
        polys = [s for s in shapes if s.get("shape_type") == "polygon"]
        if resized:
            for s in polys:
                s["points"] = [[x * sx, y * sy] for x, y in s["points"]]
        return polys
    except Exception as e:  # deliberate best-effort: sub-detect must not abort rendering
        print(f" SAM3 오류: {e}")
        return []
|
|
|
|
|
|
def _label_to_subcat(label_str: str, categories=None):
    """Map a SAM3 label string to its sub-category entry.

    Matching is a case-insensitive substring test against each category's
    ``keywords``. When keywords from several categories match, the category
    with the longest matching keyword wins. A specific phrase such as
    "pole cross arm" therefore resolves to ``bracket_arm`` (via "cross arm")
    instead of ``pole_shaft`` (via "pole"). Ties keep the earlier category.

    Args:
        label_str: label text returned by SAM3; ``None`` is treated as "".
        categories: category dicts to search (each needs ``keywords``);
            defaults to the module-level ``SUB_CATEGORIES``.

    Returns:
        The best-matching category dict, or ``None`` if nothing matches.
    """
    if categories is None:
        categories = SUB_CATEGORIES
    ll = (label_str or "").lower()
    best, best_len = None, 0
    for cat in categories:
        for kw in cat["keywords"]:
            # Strict '>' keeps the first category on equal-length matches.
            if kw and kw in ll and len(kw) > best_len:
                best, best_len = cat, len(kw)
    return best
|
|
|
|
|
|
def _subdetect_group(img: np.ndarray, group_polys: list, pad: int = 80) -> int:
    """Sub-detect parts inside one polygon group with SAM3 and draw them on ``img``.

    Steps:
    1. Pad the union bbox of ``group_polys`` by ``pad`` px, clip it to the
       image, and crop.
    2. Send the crop to SAM3 with every ``SUB_CATEGORIES`` prompt joined
       into one prompt.
    3. Draw each returned polygon on ``img`` in place: a 30 % alpha fill in
       its sub-category colour (grey if unmatched), an outline, and a name
       label.
    4. Outline the crop rectangle.

    Args:
        img: BGR image, modified in place.
        group_polys: list of (N, 2) point arrays in ``img`` pixel coordinates.
        pad: margin in px added around the union bbox.

    Returns:
        Number of sub-detected polygons actually drawn. Returns 0 without
        drawing anything if ``group_polys`` is empty or the padded bbox does
        not overlap the image.
    """
    if not group_polys:
        return 0
    all_pts = np.vstack(group_polys)
    H, W = img.shape[:2]
    x0 = max(0, int(all_pts[:, 0].min()) - pad)
    y0 = max(0, int(all_pts[:, 1].min()) - pad)
    x1 = min(W, int(all_pts[:, 0].max()) + pad)
    y1 = min(H, int(all_pts[:, 1].max()) + pad)
    if x1 <= x0 or y1 <= y0:
        # Polygons lie entirely outside the image: the crop would be empty.
        return 0

    crop = img[y0:y1, x0:x1].copy()
    combined_prompt = ", ".join(c["prompt"] for c in SUB_CATEGORIES)
    shapes = _sam3_call(crop, combined_prompt)

    font = cv2.FONT_HERSHEY_SIMPLEX
    n_drawn = 0
    for s in shapes:
        points = s.get("points") or []
        if len(points) < 3:
            # Not a drawable polygon (label polygons use the same >= 3 rule);
            # an empty list would also break the centroid indexing below.
            continue
        cat = _label_to_subcat(s.get("label") or "")
        color = tuple(cat["color_bgr"]) if cat else (128, 128, 128)
        # Shift crop-local points back into full-image coordinates.
        pts = np.array([[int(px + x0), int(py + y0)] for px, py in points], dtype=np.int32)
        overlay = img.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.30, img, 0.70, 0, img)
        cv2.polylines(img, [pts], True, color, 2)
        cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
        name = cat["name"] if cat else "?"
        cv2.putText(img, name, (cx, cy), font, 0.7, color, 2, cv2.LINE_AA)
        n_drawn += 1

    cv2.rectangle(img, (x0, y0), (x1, y1), (255, 200, 0), 2)
    return n_drawn
|
|
|
|
|
|
def _parse_label_polygons(text: str, W: int, H: int, class_ids=None) -> list:
    """Parse polygon label text into pixel-space int32 point arrays.

    Each non-empty line is ``<class_id> x1 y1 x2 y2 ...``. Coordinates are
    multiplied by ``W`` / ``H`` here, so they are expected to be normalised.
    Lines whose class id is not in ``class_ids`` (when given) are skipped.
    A trailing unpaired coordinate is ignored. Polygons with fewer than
    3 points are dropped.
    """
    polygons = []
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue
        if class_ids is not None and int(parts[0]) not in class_ids:
            continue
        coords = [float(v) for v in parts[1:]]
        # Step over (x, y) pairs only; an odd trailing value is ignored.
        pts = np.array(
            [[coords[k] * W, coords[k + 1] * H] for k in range(0, len(coords) - 1, 2)],
            dtype=np.int32,
        )
        if len(pts) >= 3:
            polygons.append(pts)
    return polygons


def _poly_orient(pts, img_cx: float, img_cy: float):
    """Classify a polygon as vertical ('V') or horizontal ('H').

    In oblique drone shots, vertical poles appear tilted radially away from
    the image centre. So the long axis of the polygon's cv2.minAreaRect is
    compared with the radial direction (image centre -> rect centre).

    Returns:
        ``(color_bgr, tag, cos_sim)``. ``cos_sim`` is |cos| between the long
        axis and the radial direction. ``tag`` is 'V' when ``cos_sim > 0.7``
        and 'H' otherwise. ``('?', 0.0)`` in grey is returned when the rect
        is degenerate (short side < 1 px), nearly square (aspect ratio < 1.3),
        or centred within 1 px of the image centre.
    """
    unknown = ((160, 160, 160), '?', 0.0)
    (rx, ry), (rw, rh), angle = cv2.minAreaRect(pts)
    if min(rw, rh) < 1:
        return unknown
    if max(rw, rh) / min(rw, rh) < 1.3:
        return unknown
    # Unit vector along the rect's long side.
    long_angle_deg = angle if rw >= rh else angle + 90
    lx = np.cos(np.radians(long_angle_deg))
    ly = np.sin(np.radians(long_angle_deg))
    # Unit radial vector: image centre -> polygon centre.
    rdx, rdy = rx - img_cx, ry - img_cy
    radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
    if radial_norm < 1:
        return unknown
    rdx, rdy = rdx / radial_norm, rdy / radial_norm
    # cos_sim -> 1: long axis is radial (vertical pole); -> 0: perpendicular (horizontal beam).
    cos_sim = abs(lx * rdx + ly * rdy)
    if cos_sim > 0.7:
        return (255, 80, 0), 'V', cos_sim
    return (0, 80, 255), 'H', cos_sim


def _prune_spurs(sk, min_arm_px):
    """Remove short skeleton arms.

    Branch pixels are those whose crossing number (0->1 transitions around
    the 8-neighbour ring) is >= 3. Removing them splits the skeleton into
    arms, and every arm with fewer than ``min_arm_px`` pixels is deleted
    from a copy of ``sk``.

    NOTE(review): this removes any short arm, including a short segment
    lying between two junctions (not only dead-end spurs). That can split a
    skeleton component in two. Confirm this is intended.

    Returns:
        ``(pruned_skeleton, removed_arm_centroids)``.
    """
    s_i = (sk > 0).astype(np.int32)
    sp = np.pad(s_i, 1, mode='constant')
    h, w = s_i.shape
    # 8-neighbours in circular order starting at the top-left.
    ring = [sp[0:h, 0:w], sp[0:h, 1:w + 1], sp[0:h, 2:w + 2],
            sp[1:h + 1, 2:w + 2], sp[2:h + 2, 2:w + 2], sp[2:h + 2, 1:w + 1],
            sp[2:h + 2, 0:w], sp[1:h + 1, 0:w]]
    cn = np.zeros((h, w), dtype=np.int32)
    for k in range(8):
        cn += (1 - ring[k]) * ring[(k + 1) % 8]
    cn *= s_i
    arm_sk = sk.copy()
    arm_sk[cn >= 3] = 0  # drop branch pixels so each arm is its own component
    n_arm, arm_lbl, arm_stats, arm_cents = cv2.connectedComponentsWithStats(arm_sk)
    sk_out = sk.copy()
    spur_cs = []
    for a_id in range(1, n_arm):
        if arm_stats[a_id, cv2.CC_STAT_AREA] < min_arm_px:
            sk_out[arm_lbl == a_id] = 0
            spur_cs.append((int(arm_cents[a_id, 0]), int(arm_cents[a_id, 1])))
    return sk_out, spur_cs


def _cluster_centroids(mask, cluster_radius):
    """Merge nearby mask pixels and return one (x, y) centroid per cluster.

    The mask is dilated with a ``cluster_radius`` x ``cluster_radius``
    square kernel. The centroids of the dilated blobs are returned.
    """
    if not mask.any():
        return []
    dilated = cv2.dilate(mask, np.ones((cluster_radius, cluster_radius), np.uint8))
    num, _, _, cents = cv2.connectedComponentsWithStats(dilated)
    return [(int(cents[i, 0]), int(cents[i, 1])) for i in range(1, num)]


def _virtual_endpoints(comp_skel, endpoints, branches, min_sep_px: int = 80):
    """Add stand-in endpoints for a component that has exactly 3 real endpoints.

    This targets a frame structure that is missing one endpoint:
    1. Assign each real endpoint to the nearest side of the skeleton bbox.
    2. For each unoccupied side, take the most extreme branch point in that
       direction.
    3. Add that point unless it lies within ``min_sep_px`` of an existing
       (real or virtual) endpoint.

    Sides are visited in a fixed order so the result is reproducible
    (iterating a ``set`` of strings depends on the hash seed).
    """
    virtual = []
    if len(endpoints) != 3 or not branches:
        return virtual
    ys, xs = np.where(comp_skel > 0)
    if len(ys) == 0:
        return virtual
    xmin, xmax = int(xs.min()), int(xs.max())
    ymin, ymax = int(ys.min()), int(ys.max())

    occupied = set()
    for ex, ey in endpoints:
        side_dists = (('top', ey - ymin), ('bottom', ymax - ey),
                      ('left', ex - xmin), ('right', xmax - ex))
        occupied.add(min(side_dists, key=lambda sd: sd[1])[0])

    missing = [side for side in ('top', 'bottom', 'left', 'right') if side not in occupied]
    for side in missing:
        if side == 'right':
            target = max(branches, key=lambda b: b[0])
        elif side == 'left':
            target = min(branches, key=lambda b: b[0])
        elif side == 'top':
            target = min(branches, key=lambda b: b[1])
        else:  # 'bottom'
            target = max(branches, key=lambda b: b[1])
        too_close = any((target[0] - px) ** 2 + (target[1] - py) ** 2 < min_sep_px ** 2
                        for px, py in endpoints + virtual)
        if not too_close:
            virtual.append(target)
    print(f" virtual endpoint: missing={missing}, "
          f"added={virtual}")
    return virtual


def _node_dist(a, b):
    """Euclidean distance between two (x, y) points."""
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5


def _topology_lines(nodes):
    """Build topology connection lines from the top-most node.

    Steps:
    1. T is the node with the smallest y (ties broken by x).
    2. Connect T to the left-most and right-most remaining nodes, or to the
       only other node if there is just one.
    3. From the end of the longer of those connections, repeatedly connect
       to the nearest unconnected node lying strictly below the current one.

    Returns:
        A list of ``(p1, p2)`` point pairs; empty when fewer than 2 nodes.
    """
    if len(nodes) < 2:
        return []
    top_i = min(range(len(nodes)), key=lambda i: (nodes[i][1], nodes[i][0]))
    T = nodes[top_i]
    other_idx = [i for i in range(len(nodes)) if i != top_i]
    connected = {top_i}
    lines = []

    if len(other_idx) == 1:
        longer_i = other_idx[0]
        lines.append((T, nodes[longer_i]))
        connected.add(longer_i)
    else:
        left_i = min(other_idx, key=lambda i: nodes[i][0])
        right_i = max(other_idx, key=lambda i: nodes[i][0])
        lines.append((T, nodes[left_i]))
        # If every other node shares one x, left and right coincide: draw once.
        if right_i != left_i:
            lines.append((T, nodes[right_i]))
        connected.update((left_i, right_i))
        d_l = _node_dist(T, nodes[left_i])
        d_r = _node_dist(T, nodes[right_i])
        longer_i = left_i if d_l >= d_r else right_i

    current_i = longer_i
    while True:
        cur = nodes[current_i]
        below = [(i, _node_dist(cur, nodes[i]))
                 for i in range(len(nodes))
                 if i not in connected and nodes[i][1] > cur[1]]
        if not below:
            break
        next_i = min(below, key=lambda x: x[1])[0]
        lines.append((cur, nodes[next_i]))
        connected.add(next_i)
        current_i = next_i
    return lines


# BGR colours cycled per skeleton component in the saved group view.
_GROUP_PALETTE = [
    (255, 80, 80), (80, 220, 80), (80, 120, 255), (220, 200, 50),
    (220, 80, 220), (60, 220, 220), (255, 160, 60), (100, 200, 120),
    (180, 80, 255), (255, 140, 100),
]


def _draw_group_view(base_img, keep_idx, polygons, cid_to_polys):
    """Return a copy of ``base_img`` with polygons coloured by skeleton component.

    Polygons that belong to no component are grey. Each polygon is labelled
    with its index. Each component gets a "C<cid>" tag below the mean of its
    polygons' centroids.
    """
    img = base_img.copy()
    poly2comp = {pi: cid for cid, pis in cid_to_polys.items() for pi in pis}
    idx2pos = {orig: pos for pos, orig in enumerate(keep_idx)}

    for orig_idx, pts in zip(keep_idx, polygons):
        cid = poly2comp.get(orig_idx)
        col = _GROUP_PALETTE[(cid - 1) % len(_GROUP_PALETTE)] if cid else (140, 140, 140)
        ov = img.copy()
        cv2.fillPoly(ov, [pts], col)
        cv2.addWeighted(ov, 0.35, img, 0.65, 0, img)
        cv2.polylines(img, [pts], True, col, 3)
        cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
        cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4)
        cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2)

    for cid, pis in cid_to_polys.items():
        col = _GROUP_PALETTE[(cid - 1) % len(_GROUP_PALETTE)]
        valid = [idx2pos[pi] for pi in pis if pi in idx2pos]
        if not valid:
            continue
        cx = int(sum(int(polygons[i][:, 0].mean()) for i in valid) / len(valid))
        cy = int(sum(int(polygons[i][:, 1].mean()) for i in valid) / len(valid))
        cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 7)
        cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, col, 3)
    return img


def render(image_path: Path, label_path: Path, output_path: Path,
           selected_indices=None, branch_radius: int = 10, class_ids=None,
           subdetect: bool = False, subdetect_polys: set = None):
    """Analyse the label polygons of one image and save a per-group visualisation.

    Pipeline:
    1. Load the image. Parse polygons from ``label_path`` (optionally
       filtered by ``class_ids``), then keep ``selected_indices``.
    2. Classify each polygon vertical/horizontal and print the result.
    3. Fill the union mask, bridge small gaps, skeletonise, and prune short
       arms.
    4. Group polygons by the skeleton component they overlap. Print
       diagnostics plus branch/endpoint/topology statistics per component.
    5. Optionally run SAM3 sub-detection: on all groups with ``subdetect``,
       or only on the groups containing ``subdetect_polys``.
    6. Save the original image overlaid with polygons coloured by group,
       downscaled so the longer side is <= 4096 px.

    Polygon indices everywhere (printed output, ``selected_indices``,
    ``subdetect_polys``, on-image labels) refer to positions in the parsed,
    class-filtered polygon list.

    Args:
        image_path: input image (read via np.fromfile, so non-ASCII paths work).
        label_path: polygon label file; a missing file means "no polygons".
        output_path: output image; the extension selects the encoder.
        selected_indices: polygon indices to keep; ``None`` keeps all.
        branch_radius: accepted for CLI compatibility; currently unused.
        class_ids: class ids to keep; ``None`` keeps all.
        subdetect: run SAM3 sub-detect on every group.
        subdetect_polys: run sub-detect only on groups containing these polygons.

    Raises:
        OSError: the image file cannot be read (raised by np.fromfile).
        ValueError: the image cannot be decoded, or the output cannot be encoded.
    """
    buf = np.fromfile(str(image_path), dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"cannot decode image: {image_path}")
    img_orig = img.copy()
    H, W = img.shape[:2]

    # ── 1. polygons ──────────────────────────────────────────────────────
    text = label_path.read_text(encoding="utf-8").strip() if label_path.exists() else ""
    all_polygons = _parse_label_polygons(text, W, H, class_ids)
    if selected_indices is not None:
        keep_idx = [i for i in sorted(selected_indices) if 0 <= i < len(all_polygons)]
    else:
        keep_idx = list(range(len(all_polygons)))
    polygons = [all_polygons[i] for i in keep_idx]
    print(f" total polygons in label: {len(all_polygons)}, "
          f"selected: {len(polygons)} (indices: {keep_idx})")

    # ── 2. vertical / horizontal classification ─────────────────────────
    img_cx, img_cy = W / 2.0, H / 2.0
    orient_map = {}  # polygon index -> 'V' / 'H' / '?'
    print("\n [폴리곤 수직/수평 분류]")
    for orig_idx, pts in zip(keep_idx, polygons):
        color, orient, cos_sim = _poly_orient(pts, img_cx, img_cy)
        orient_map[orig_idx] = orient
        print(f" poly {orig_idx:>2d}: {orient} cos_sim={cos_sim:.3f}")
        overlay = img.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.25, img, 0.75, 0, img)
        cv2.polylines(img, [pts], True, color, 2)
        cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
        label = f"{orig_idx}{orient}"
        cv2.putText(img, label, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 5)
        cv2.putText(img, label, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2)

    # ── 3. union mask -> skeleton ───────────────────────────────────────
    merged = np.zeros((H, W), dtype=np.uint8)
    for pts in polygons:
        cv2.fillPoly(merged, [pts], 255)
    # The 11x11 dilation (5 px per side) bridges gaps of up to ~10 px between
    # neighbouring polygons. The 7x7 erosion (3 px) then shrinks most of that
    # growth back. The net effect is to MERGE nearby polygons, not separate them.
    merged_dilated = cv2.dilate(merged, np.ones((11, 11), dtype=np.uint8))
    merged_eroded = cv2.erode(merged_dilated, np.ones((7, 7), dtype=np.uint8))
    skel = cv2.ximgproc.thinning(merged_eroded)

    # Thickened skeleton drawn in black for visibility.
    skel_thick = cv2.dilate(skel, np.ones((5, 5), dtype=np.uint8))
    img[skel_thick > 0] = (0, 0, 0)

    skel_pruned, spur_centroids = _prune_spurs(skel, min_arm_px=30)
    print(f" spur pruning: {len(spur_centroids)}개 arm 제거 (min_arm_px=30)")

    # Per-pixel 8-neighbour count on the pruned skeleton.
    skel_on = (skel_pruned > 0).astype(np.float32)
    nbr_count = cv2.filter2D(skel_on, cv2.CV_32F, np.ones((3, 3), dtype=np.float32)) - skel_on
    degree = (nbr_count * skel_on).astype(np.int32)
    branch_mask_global = ((degree >= 3) & (skel_pruned > 0)).astype(np.uint8) * 255
    endpoint_mask_global = ((degree == 1) & (skel_pruned > 0)).astype(np.uint8) * 255

    # ── 4. grouping = connected components of the pruned skeleton ───────
    num_comp, comp_labels = cv2.connectedComponents(skel_pruned)

    print(f"\n [Diagnostic] skeleton components: {num_comp - 1}")
    # Keyed by polygon index (same keys as orient_map / subdetect_polys).
    poly_to_cids: dict = {}
    for orig_idx, pts in zip(keep_idx, polygons):
        poly_mask = np.zeros((H, W), dtype=np.uint8)
        cv2.fillPoly(poly_mask, [pts], 255)
        in_skel = (skel_pruned > 0) & (poly_mask > 0)
        if in_skel.any():
            poly_to_cids[orig_idx] = sorted(
                set(int(c) for c in np.unique(comp_labels[in_skel]) if c > 0))
        else:
            print(f" poly {orig_idx:>2d}: (skeleton 없음 — 너무 작거나 erosion됨)")

    # Reverse map: component -> polygon indices.
    cid_to_polys: dict = {}
    for poly_i, cids in poly_to_cids.items():
        for cid in cids:
            cid_to_polys.setdefault(cid, []).append(poly_i)

    print(f"\n [그룹 → 폴리곤 목록] (전체 {len(cid_to_polys)}그룹)")
    for cid in sorted(cid_to_polys.keys()):
        n_px = int((comp_labels == cid).sum())
        polys = cid_to_polys[cid]
        orients = [f"{pi}({orient_map.get(pi, '?')})" for pi in polys]
        has_v = any(orient_map.get(pi) == 'V' for pi in polys)
        has_h = any(orient_map.get(pi) == 'H' for pi in polys)
        tag = " ★라멘?" if (has_v and has_h) else ""
        print(f" Comp {cid:>2d} ({n_px:>5d}px): {orients}{tag}")
    print()

    # ── 5. optional sub-detect setup ────────────────────────────────────
    do_subdetect = subdetect or bool(subdetect_polys)
    target_cids: set = set()
    if subdetect_polys:
        for pi in subdetect_polys:
            target_cids.update(poly_to_cids.get(pi, []))
        print(f" sub-detect 대상 그룹: {sorted(target_cids)} (poly {sorted(subdetect_polys)} 기준)\n")

    # component -> polygon point arrays (same membership as poly_to_cids).
    comp_to_polys: dict = {}
    if do_subdetect:
        pts_by_idx = dict(zip(keep_idx, polygons))
        for orig_idx, cids in poly_to_cids.items():
            for cid_ in cids:
                comp_to_polys.setdefault(cid_, []).append(pts_by_idx[orig_idx])

    # ── per-component topology ──────────────────────────────────────────
    total_branches = 0
    total_endpoints = 0
    total_groups = 0
    total_lines = 0

    for cid in range(1, num_comp):
        comp_mask = (comp_labels == cid)
        if int(comp_mask.sum()) < 50:
            continue  # ignore tiny components
        total_groups += 1

        comp_skel = (skel_pruned * comp_mask.astype(np.uint8)).astype(np.uint8)
        comp_branch = cv2.bitwise_and(branch_mask_global, comp_skel)
        comp_endpoint = cv2.bitwise_and(endpoint_mask_global, comp_skel)
        branches = _cluster_centroids(comp_branch, 7)
        endpoints = _cluster_centroids(comp_endpoint, 5)

        virtual_endpoints = _virtual_endpoints(comp_skel, endpoints, branches)
        endpoints_all = endpoints + virtual_endpoints
        total_branches += len(branches)
        total_endpoints += len(endpoints_all)

        for p1, p2 in _topology_lines(endpoints_all):
            cv2.line(img, p1, p2, (0, 255, 255), 5)
            total_lines += 1
            print(f" line {p1}→{p2}: {_node_dist(p1, p2):.0f}px")

        # Node markers (drawn after lines so they sit on top).
        for cx, cy in branches:
            cv2.circle(img, (cx, cy), 14, (0, 220, 0), 3)  # green: T/X junction
        for cx, cy in endpoints:
            cv2.circle(img, (cx, cy), 18, (0, 255, 0), 4)
        for cx, cy in virtual_endpoints:
            cv2.circle(img, (cx, cy), 18, (0, 140, 255), 4)  # orange: virtual endpoint

        if not do_subdetect or cid not in comp_to_polys:
            should_sub = False
        elif subdetect_polys:
            # Restrict to the requested groups even when none matched; an
            # empty target set must not fall back to "every group".
            should_sub = cid in target_cids
        else:
            should_sub = True
        if should_sub:
            n = _subdetect_group(img, comp_to_polys[cid])
            print(f" 그룹 {cid}: sub-detect {n}개")

    # Removed spur arms: small pink circles.
    for cx, cy in spur_centroids:
        cv2.circle(img, (cx, cy), 6, (180, 120, 180), 2)

    print(f" {len(polygons)} polygons, skeleton 픽셀 {int((skel > 0).sum())}")
    print(f" 그룹(skeleton component) {total_groups}개")
    print(f" 분기점 {total_branches}개, 끝점 {total_endpoints}개, 연결선 {total_lines}개")

    # ── 6. group visualisation on the ORIGINAL image ────────────────────
    # NOTE(review): everything drawn on `img` above is discarded here. That
    # includes the orientation overlay, skeleton, topology lines, node
    # circles and sub-detect results; the saved file only shows the
    # per-group polygon view. Confirm this is intended.
    img = _draw_group_view(img_orig, keep_idx, polygons, cid_to_polys)

    h, w = img.shape[:2]
    scale = min(1.0, 4096 / max(h, w))
    if scale < 1.0:
        img = cv2.resize(img, (int(w * scale), int(h * scale)))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    ok, encoded = cv2.imencode(output_path.suffix, img)
    if not ok:
        raise ValueError(f"cannot encode image as {output_path.suffix!r}: {output_path}")
    encoded.tofile(str(output_path))
    print(f" → {output_path}")
|
|
|
|
|
|
def _parse_index_set(text: str):
    """Parse a comma-separated integer list such as ``"1, 3,5"``.

    Returns ``None`` for a blank string (meaning "no filter"), otherwise the
    set of integers. Empty items (``"1,,2"``) are ignored.

    Raises:
        ValueError: if an item is not an integer.
    """
    if not text.strip():
        return None
    return {int(item) for item in (raw.strip() for raw in text.split(",")) if item}


def main(argv=None):
    """Command-line entry point: parse arguments and call ``render``.

    Args:
        argv: argument list to parse; ``None`` (the default) uses ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True)
    ap.add_argument("--label", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--polygons", type=str, default="",
                    help="comma-separated polygon indices (예: '10,11,12,26,27'). 비우면 전체")
    ap.add_argument("--class-ids", type=str, default="",
                    help="포함할 클래스 ID (예: '1' or '1,3'). 비우면 전체")
    # NOTE: render() currently accepts but does not use this value.
    ap.add_argument("--branch-radius", type=int, default=10,
                    help="branch 판정 원 반지름 px (기본 10)")
    ap.add_argument("--subdetect", action="store_true",
                    help="모든 그룹 sub-detect (SAM3 서버 필요)")
    ap.add_argument("--subdetect-polys", type=str, default="",
                    help="지정 polygon이 속한 그룹만 sub-detect (예: '2,27')")
    args = ap.parse_args(argv)

    index_sets = {}
    for opt, raw in (("--polygons", args.polygons),
                     ("--class-ids", args.class_ids),
                     ("--subdetect-polys", args.subdetect_polys)):
        try:
            index_sets[opt] = _parse_index_set(raw)
        except ValueError:
            # Report through argparse (usage + exit code 2) instead of a traceback.
            ap.error(f"{opt}: expected comma-separated integers, got {raw!r}")

    render(Path(args.image), Path(args.label), Path(args.output), index_sets["--polygons"],
           branch_radius=args.branch_radius, class_ids=index_sets["--class-ids"],
           subdetect=args.subdetect, subdetect_polys=index_sets["--subdetect-polys"])


if __name__ == "__main__":
    main()
|