Files
railway-client/tools/video_sam3_segment.py
minsung ccba1266b5 프로젝트 분리 이동
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 14:28:27 +09:00

238 lines
8.3 KiB
Python

"""
드론 영상에서 프레임 추출 → SAM3 서버로 모든 객체 세그멘테이션
Usage:
1. SAM3 서버 시작: start_server.bat
2. python tools/video_sam3_segment.py
"""
import base64
import cv2
import json
import numpy as np
import requests
import sys
from pathlib import Path
# === Settings ===
# NOTE(review): these paths are relative to the current working directory;
# the usage line suggests running from the project root — confirm.
VIDEO_PATH = Path("sample/rail.mp4")
OUTPUT_DIR = Path("output/video_segmentation")
SAM3_URL = "http://localhost:8000"
MODEL_ID = "segment_anything_3"
FRAME_INTERVAL = 30  # keep every 30th frame (one frame per second for 30 fps video)

# Text prompt -> overlay color (BGR). Railway facilities plus general objects.
# This table is the single source of truth for the prompt list (see PROMPTS
# below), so a prompt can never be added without a color or vice versa.
COLORS: dict[str, tuple[int, int, int]] = {
    "catenary pole": (0, 0, 255),    # red; overhead-line support pole
    "junction box": (0, 165, 255),   # orange; electrical box
    "utility box": (0, 255, 255),    # yellow; telecom box
    "rail track": (255, 0, 0),       # blue; rails
    "fence": (255, 0, 255),          # magenta
    "cable": (0, 255, 0),            # green; wires / cables
    "sign": (255, 255, 0),           # cyan; signage
    "building": (128, 128, 255),     # light red
    "vegetation": (0, 128, 0),       # dark green
}

# Prompts sent to the server, in COLORS insertion order.
PROMPTS: list[str] = list(COLORS)
def extract_frames(video_path: Path, interval: int) -> list[tuple[int, np.ndarray]]:
    """Extract every ``interval``-th frame from a video file.

    Frames that are skipped are only grabbed (not decoded), which avoids
    decoding work for frames that would be thrown away.

    Args:
        video_path: Path to the video file.
        interval: Keep frames whose 0-based index is a multiple of this value.
            Must be >= 1.

    Returns:
        List of ``(frame_index, frame)`` tuples in playback order, where
        ``frame`` is the array returned by ``cv2.VideoCapture.read``.

    Raises:
        ValueError: If ``interval`` is less than 1.

    Exits the process with status 1 if the video cannot be opened.
    """
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        print(f"ERROR: Cannot open video: {video_path}")
        sys.exit(1)

    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 fps; avoid a ZeroDivisionError in the banner.
        duration = f"{total / fps:.1f}s" if fps > 0 else "unknown duration"
        print(f"Video: {video_path.name} | {total} frames @ {fps:.0f}fps | {duration}")

        idx = 0
        while True:
            if idx % interval == 0:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append((idx, frame))
            elif not cap.grab():
                # grab() advances without decoding; False means end of stream.
                break
            idx += 1
    finally:
        cap.release()

    print(f"Extracted {len(frames)} frames (interval={interval})")
    return frames
def encode_frame(frame: np.ndarray, quality: int = 95) -> str:
    """Encode a frame as JPEG and return it as a base64 text string.

    Args:
        frame: Image array accepted by ``cv2.imencode``.
        quality: JPEG quality passed as ``cv2.IMWRITE_JPEG_QUALITY``
            (default 95, the previous hard-coded value).

    Returns:
        Base64 encoding of the JPEG bytes, decoded to ``str``.

    Raises:
        ValueError: If OpenCV reports that JPEG encoding failed, instead of
            silently sending an empty image to the server.
    """
    ok, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        raise ValueError("JPEG encoding of frame failed")
    return base64.b64encode(buffer).decode("utf-8")
def predict_sam3(
    image_b64: str,
    text_prompt: str,
    conf_threshold: float = 0.25,
    timeout: float = 60,
) -> dict:
    """Send one text-prompted segmentation request to the SAM3 server.

    Args:
        image_b64: Base64-encoded image.
        text_prompt: Text prompt describing the objects to segment.
        conf_threshold: Value sent as ``params.conf_threshold`` (default 0.25).
        timeout: HTTP timeout in seconds (default 60).

    Returns:
        The decoded JSON response when it is a JSON object. Otherwise returns
        ``{"success": False}``. That happens on HTTP/transport errors, an
        invalid JSON body, or a non-object JSON body.

    Exits the process with status 1 if the server cannot be reached at all.
    """
    payload = {
        "model": MODEL_ID,
        "image": image_b64,
        "params": {
            "text_prompt": text_prompt,
            "conf_threshold": conf_threshold,
        },
    }
    try:
        resp = requests.post(f"{SAM3_URL}/v1/predict", json=payload, timeout=timeout)
        resp.raise_for_status()
        result = resp.json()
    # ConnectionError is a RequestException subclass, so it must come first.
    except requests.exceptions.ConnectionError:
        print("ERROR: SAM3 서버에 연결할 수 없습니다. start_server.bat을 먼저 실행하세요.")
        sys.exit(1)
    # Best effort: HTTP errors, timeouts and bad JSON (ValueError) skip only
    # this prompt, while programming errors still propagate.
    except (requests.exceptions.RequestException, ValueError) as e:
        print(f" ERROR [{text_prompt}]: {e}")
        return {"success": False}

    if not isinstance(result, dict):
        # Callers use .get() on the result; a list/str body would crash them.
        print(f" ERROR [{text_prompt}]: unexpected response type {type(result).__name__}")
        return {"success": False}
    return result
def _as_point(xy) -> tuple[int, int]:
    """Convert an (x, y) pair to plain Python ints for OpenCV drawing calls.

    Some OpenCV builds reject NumPy scalar coordinates in point tuples.
    """
    return int(xy[0]), int(xy[1])


def draw_shapes_on_frame(frame: np.ndarray, shapes: list, prompt: str) -> np.ndarray:
    """Draw one prompt's segmentation results onto ``frame``.

    Filled regions are alpha-blended into the frame (30% fill, 70% original).
    Outlines and labels are drawn after the blend at full opacity so the blend
    does not wash them out.

    Args:
        frame: Image to draw on; modified in place.
        shapes: Shape dicts with ``"points"`` and optional ``"shape_type"``
            (``"polygon"`` when absent, or ``"rectangle"`` with two corner
            points), ``"label"`` and ``"score"``.
        prompt: Text prompt. It selects the color from ``COLORS`` (gray
            fallback) and is the label used when a shape has none.

    Returns:
        ``frame`` itself, so calls can be chained or reassigned.
    """
    color = COLORS.get(prompt, (200, 200, 200))
    parsed = [
        (shape, np.array(shape["points"], dtype=np.int32), shape.get("shape_type", "polygon"))
        for shape in shapes
    ]

    # Pass 1: filled regions on a copy, then blend the copy back into frame.
    overlay = frame.copy()
    for _, points, shape_type in parsed:
        if shape_type == "polygon" and len(points) >= 3:
            cv2.fillPoly(overlay, [points], color)
        elif shape_type == "rectangle" and len(points) == 2:
            cv2.rectangle(overlay, _as_point(points[0]), _as_point(points[1]), color, -1)
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)

    # Pass 2: crisp outlines and labels on top of the blended result.
    for shape, points, shape_type in parsed:
        if shape_type == "polygon" and len(points) >= 3:
            cv2.polylines(frame, [points], True, color, 2)
        elif shape_type == "rectangle" and len(points) == 2:
            cv2.rectangle(frame, _as_point(points[0]), _as_point(points[1]), color, 2)

        if len(points) == 0:
            continue
        label = shape.get("label") or prompt
        score = shape.get("score")
        # `is not None` so that a legitimate score of 0.0 is still displayed.
        text = f"{label}" + (f" {score:.2f}" if score is not None else "")
        tx = int(points[0][0])
        # Keep the text baseline inside the image for shapes at the top edge.
        ty = max(int(points[0][1]) - 5, 12)
        cv2.putText(frame, text, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return frame
def create_layer_image(
    frame_shape: tuple,
    shapes: list,
    prompt: str,
    alpha_value: int = 180,
) -> np.ndarray:
    """Build a 4-channel per-class layer: colored masks on a transparent background.

    Shape handling matches ``draw_shapes_on_frame``. A missing
    ``"shape_type"`` is treated as ``"polygon"`` (needs >= 3 points), and
    ``"rectangle"`` shapes (two corner points) are filled as well. The
    composite and the layer therefore show the same detections.

    Args:
        frame_shape: Shape of the source frame; only ``[:2]`` (height, width)
            is used.
        shapes: Shape dicts with ``"points"`` and optional ``"shape_type"``.
        prompt: Text prompt; selects the color from ``COLORS`` (gray fallback).
        alpha_value: Alpha written inside masks, 0-255 (default 180). Pixels
            outside every mask get alpha 0.

    Returns:
        ``uint8`` array of shape (height, width, 4): three color channels
        followed by alpha, suitable for writing as PNG with ``cv2.imwrite``.

    Raises:
        ValueError: If ``alpha_value`` is outside 0-255.
    """
    if not 0 <= alpha_value <= 255:
        raise ValueError(f"alpha_value must be in 0..255, got {alpha_value}")

    h, w = frame_shape[:2]
    color = COLORS.get(prompt, (200, 200, 200))
    bgr = np.zeros((h, w, 3), dtype=np.uint8)
    alpha = np.zeros((h, w), dtype=np.uint8)

    for shape in shapes:
        points = np.array(shape["points"], dtype=np.int32)
        shape_type = shape.get("shape_type", "polygon")
        if shape_type == "polygon" and len(points) >= 3:
            cv2.fillPoly(bgr, [points], color)
            cv2.fillPoly(alpha, [points], alpha_value)
        elif shape_type == "rectangle" and len(points) == 2:
            # Plain ints: some OpenCV builds reject NumPy scalar coordinates.
            pt1 = (int(points[0][0]), int(points[0][1]))
            pt2 = (int(points[1][0]), int(points[1][1]))
            cv2.rectangle(bgr, pt1, pt2, color, -1)
            cv2.rectangle(alpha, pt1, pt2, alpha_value, -1)

    return np.dstack([bgr, alpha])
def _check_server() -> None:
    """Print the SAM3 /health response, or exit with status 1 if the server is unreachable."""
    print("=== SAM3 서버 연결 확인 ===")
    try:
        health = requests.get(f"{SAM3_URL}/health", timeout=5)
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        print("ERROR: SAM3 서버가 실행 중이 아닙니다!")
        print(" → start_server.bat을 먼저 실행하세요.")
        sys.exit(1)
    try:
        status = health.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): report it instead of crashing.
        status = f"HTTP {health.status_code}: {health.text[:200]}"
    print(f"Server status: {status}")


def _save_image(path: Path, image: np.ndarray) -> None:
    """Write an image with cv2.imwrite, printing a warning if OpenCV reports failure."""
    # cv2.imwrite returns False rather than raising on failure.
    if not cv2.imwrite(str(path), image):
        print(f"  WARNING: failed to write {path}")


def _segment_frame(fidx: int, frame: np.ndarray) -> dict[str, int]:
    """Run every prompt on one frame, save its images, and return per-prompt counts.

    Writes the per-prompt layer PNGs, the composite JPEG and the original JPEG
    into OUTPUT_DIR. The returned dict only contains prompts with detections.
    """
    frame_b64 = encode_frame(frame)
    detections: dict[str, int] = {}
    composite = frame.copy()

    for prompt in PROMPTS:
        print(f" Segmenting: {prompt}...", end=" ")
        result = predict_sam3(frame_b64, prompt)
        # "data" may be missing or null in a failed response; never assume a dict.
        data = result.get("data") if result.get("success") else None
        shapes = data.get("shapes") if isinstance(data, dict) else None
        if not shapes:
            print("→ no objects")
            continue

        print(f"{len(shapes)} objects found")
        detections[prompt] = len(shapes)
        composite = draw_shapes_on_frame(composite, shapes, prompt)
        layer = create_layer_image(frame.shape, shapes, prompt)
        layer_name = prompt.replace(" ", "_")
        _save_image(OUTPUT_DIR / f"frame_{fidx:04d}_layer_{layer_name}.png", layer)

    _save_image(OUTPUT_DIR / f"frame_{fidx:04d}_composite.jpg", composite)
    # Untouched frame, saved for side-by-side comparison.
    _save_image(OUTPUT_DIR / f"frame_{fidx:04d}_original.jpg", frame)
    return detections


def _write_legend() -> None:
    """Save legend.jpg with one color swatch and name per prompt, in PROMPTS order."""
    legend = np.zeros((len(PROMPTS) * 30 + 20, 300, 3), dtype=np.uint8)
    # Iterate PROMPTS (the value the height is computed from) so the rows always
    # fit; use the same gray fallback as the drawing helpers.
    for i, prompt in enumerate(PROMPTS):
        color = COLORS.get(prompt, (200, 200, 200))
        y = i * 30 + 20
        cv2.rectangle(legend, (10, y - 12), (30, y + 5), color, -1)
        cv2.putText(legend, prompt, (40, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    _save_image(OUTPUT_DIR / "legend.jpg", legend)


def main():
    """Segment sampled video frames with every prompt and write all outputs to OUTPUT_DIR.

    Steps:
    1. Check the server.
    2. Extract frames.
    3. Segment each frame with each prompt and save the layer, composite and
       original images.
    4. Write the JSON summary of detection counts.
    5. Draw the color legend.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # 1. Server health check
    _check_server()

    # 2. Frame extraction
    print("\n=== 프레임 추출 ===")
    frames = extract_frames(VIDEO_PATH, FRAME_INTERVAL)

    # 3. Per-frame, per-prompt segmentation
    all_results = {}
    for position, (fidx, frame) in enumerate(frames, start=1):
        print(f"\n=== Frame {fidx} ({position}/{len(frames)}) ===")
        all_results[f"frame_{fidx}"] = {
            "frame_index": fidx,
            "detections": _segment_frame(fidx, frame),
        }

    # 4. Summary of detection counts
    summary_path = OUTPUT_DIR / "segmentation_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    # 5. Color legend
    _write_legend()

    print("\n=== 완료 ===")
    print(f"결과 저장: {OUTPUT_DIR}/")
    print(" - frame_XXXX_original.jpg : 원본 프레임")
    print(" - frame_XXXX_composite.jpg : 전체 세그멘테이션 합성")
    print(" - frame_XXXX_layer_*.png : 클래스별 개별 레이어 (투명 배경)")
    print(" - segmentation_summary.json : 검출 요약")
    print(" - legend.jpg : 색상 범례")


if __name__ == "__main__":
    main()