238 lines
8.3 KiB
Python
238 lines
8.3 KiB
Python
"""
|
|
드론 영상에서 프레임 추출 → SAM3 서버로 모든 객체 세그멘테이션
|
|
Usage:
|
|
1. SAM3 서버 시작: start_server.bat
|
|
2. python tools/video_sam3_segment.py
|
|
"""
|
|
|
|
import base64
|
|
import cv2
|
|
import json
|
|
import numpy as np
|
|
import requests
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# === 설정 ===
|
|
VIDEO_PATH = Path("sample/rail.mp4")
|
|
OUTPUT_DIR = Path("output/video_segmentation")
|
|
SAM3_URL = "http://localhost:8000"
|
|
MODEL_ID = "segment_anything_3"
|
|
FRAME_INTERVAL = 30 # 30fps 영상에서 1초 간격
|
|
|
|
# 철도 시설물 + 일반 객체 프롬프트
|
|
PROMPTS = [
|
|
"catenary pole", # 전철주
|
|
"junction box", # 전기박스
|
|
"utility box", # 통신박스
|
|
"rail track", # 레일
|
|
"fence", # 펜스
|
|
"cable", # 전선/케이블
|
|
"sign", # 표지판
|
|
"building", # 건물
|
|
"vegetation", # 식생
|
|
]
|
|
|
|
# 클래스별 색상 (BGR)
|
|
COLORS = {
|
|
"catenary pole": (0, 0, 255), # 빨강
|
|
"junction box": (0, 165, 255), # 주황
|
|
"utility box": (0, 255, 255), # 노랑
|
|
"rail track": (255, 0, 0), # 파랑
|
|
"fence": (255, 0, 255), # 자홍
|
|
"cable": (0, 255, 0), # 초록
|
|
"sign": (255, 255, 0), # 시안
|
|
"building": (128, 128, 255), # 연한 빨강
|
|
"vegetation": (0, 128, 0), # 진한 초록
|
|
}
|
|
|
|
|
|
def extract_frames(video_path: Path, interval: int) -> list[tuple[int, np.ndarray]]:
|
|
"""동영상에서 일정 간격으로 프레임 추출"""
|
|
cap = cv2.VideoCapture(str(video_path))
|
|
if not cap.isOpened():
|
|
print(f"ERROR: Cannot open video: {video_path}")
|
|
sys.exit(1)
|
|
|
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
|
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
print(f"Video: {video_path.name} | {total} frames @ {fps:.0f}fps | {total/fps:.1f}s")
|
|
|
|
frames = []
|
|
idx = 0
|
|
while True:
|
|
ret, frame = cap.read()
|
|
if not ret:
|
|
break
|
|
if idx % interval == 0:
|
|
frames.append((idx, frame))
|
|
idx += 1
|
|
|
|
cap.release()
|
|
print(f"Extracted {len(frames)} frames (interval={interval})")
|
|
return frames
|
|
|
|
|
|
def encode_frame(frame: np.ndarray) -> str:
|
|
"""프레임을 base64로 인코딩"""
|
|
_, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
|
return base64.b64encode(buffer).decode("utf-8")
|
|
|
|
|
|
def predict_sam3(image_b64: str, text_prompt: str) -> dict:
|
|
"""SAM3 서버에 예측 요청"""
|
|
payload = {
|
|
"model": MODEL_ID,
|
|
"image": image_b64,
|
|
"params": {
|
|
"text_prompt": text_prompt,
|
|
"conf_threshold": 0.25,
|
|
},
|
|
}
|
|
try:
|
|
resp = requests.post(f"{SAM3_URL}/v1/predict", json=payload, timeout=60)
|
|
resp.raise_for_status()
|
|
return resp.json()
|
|
except requests.exceptions.ConnectionError:
|
|
print(f"ERROR: SAM3 서버에 연결할 수 없습니다. start_server.bat을 먼저 실행하세요.")
|
|
sys.exit(1)
|
|
except Exception as e:
|
|
print(f" ERROR [{text_prompt}]: {e}")
|
|
return {"success": False}
|
|
|
|
|
|
def draw_shapes_on_frame(frame: np.ndarray, shapes: list, prompt: str) -> np.ndarray:
|
|
"""세그멘테이션 결과를 프레임 위에 그리기"""
|
|
overlay = frame.copy()
|
|
color = COLORS.get(prompt, (200, 200, 200))
|
|
|
|
for shape in shapes:
|
|
points = np.array(shape["points"], dtype=np.int32)
|
|
shape_type = shape.get("shape_type", "polygon")
|
|
|
|
if shape_type == "polygon" and len(points) >= 3:
|
|
cv2.fillPoly(overlay, [points], color)
|
|
cv2.polylines(frame, [points], True, color, 2)
|
|
elif shape_type == "rectangle" and len(points) == 2:
|
|
cv2.rectangle(overlay, tuple(points[0]), tuple(points[1]), color, -1)
|
|
cv2.rectangle(frame, tuple(points[0]), tuple(points[1]), color, 2)
|
|
|
|
# 라벨 텍스트
|
|
label = shape.get("label", prompt)
|
|
score = shape.get("score")
|
|
text = f"{label}" + (f" {score:.2f}" if score else "")
|
|
if len(points) > 0:
|
|
tx, ty = int(points[0][0]), int(points[0][1]) - 5
|
|
cv2.putText(frame, text, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
|
|
|
# 반투명 오버레이 블렌딩
|
|
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
|
return frame
|
|
|
|
|
|
def create_layer_image(frame_shape: tuple, shapes: list, prompt: str) -> np.ndarray:
|
|
"""클래스별 개별 레이어 이미지 (투명 배경 위 마스크)"""
|
|
h, w = frame_shape[:2]
|
|
color = COLORS.get(prompt, (200, 200, 200))
|
|
bgr = np.zeros((h, w, 3), dtype=np.uint8)
|
|
alpha = np.zeros((h, w), dtype=np.uint8)
|
|
|
|
for shape in shapes:
|
|
points = np.array(shape["points"], dtype=np.int32)
|
|
if shape.get("shape_type") == "polygon" and len(points) >= 3:
|
|
cv2.fillPoly(bgr, [points], color)
|
|
cv2.fillPoly(alpha, [points], 180)
|
|
|
|
layer = np.dstack([bgr, alpha])
|
|
return layer
|
|
|
|
|
|
def main():
|
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
# 1. 서버 상태 확인
|
|
print("=== SAM3 서버 연결 확인 ===")
|
|
try:
|
|
health = requests.get(f"{SAM3_URL}/health", timeout=5)
|
|
print(f"Server status: {health.json()}")
|
|
except requests.exceptions.ConnectionError:
|
|
print("ERROR: SAM3 서버가 실행 중이 아닙니다!")
|
|
print(" → start_server.bat을 먼저 실행하세요.")
|
|
sys.exit(1)
|
|
|
|
# 2. 프레임 추출
|
|
print("\n=== 프레임 추출 ===")
|
|
frames = extract_frames(VIDEO_PATH, FRAME_INTERVAL)
|
|
|
|
# 3. 각 프레임별, 각 프롬프트별 세그멘테이션
|
|
all_results = {}
|
|
|
|
for frame_idx, (fidx, frame) in enumerate(frames):
|
|
print(f"\n=== Frame {fidx} ({frame_idx+1}/{len(frames)}) ===")
|
|
|
|
frame_b64 = encode_frame(frame)
|
|
frame_results = {}
|
|
composite = frame.copy()
|
|
|
|
for prompt in PROMPTS:
|
|
print(f" Segmenting: {prompt}...", end=" ")
|
|
result = predict_sam3(frame_b64, prompt)
|
|
|
|
if result.get("success") and result.get("data", {}).get("shapes"):
|
|
shapes = result["data"]["shapes"]
|
|
n = len(shapes)
|
|
print(f"→ {n} objects found")
|
|
frame_results[prompt] = shapes
|
|
|
|
# 합성 이미지에 그리기
|
|
composite = draw_shapes_on_frame(composite, shapes, prompt)
|
|
|
|
# 개별 레이어 저장 (PNG with alpha)
|
|
layer = create_layer_image(frame.shape, shapes, prompt)
|
|
layer_name = prompt.replace(" ", "_")
|
|
layer_path = OUTPUT_DIR / f"frame_{fidx:04d}_layer_{layer_name}.png"
|
|
cv2.imwrite(str(layer_path), layer)
|
|
else:
|
|
print("→ no objects")
|
|
|
|
# 합성 이미지 저장
|
|
composite_path = OUTPUT_DIR / f"frame_{fidx:04d}_composite.jpg"
|
|
cv2.imwrite(str(composite_path), composite)
|
|
|
|
# 원본 프레임 저장 (비교용)
|
|
original_path = OUTPUT_DIR / f"frame_{fidx:04d}_original.jpg"
|
|
cv2.imwrite(str(original_path), frame)
|
|
|
|
all_results[f"frame_{fidx}"] = {
|
|
"frame_index": fidx,
|
|
"detections": {
|
|
prompt: len(shapes)
|
|
for prompt, shapes in frame_results.items()
|
|
},
|
|
}
|
|
|
|
# 4. 결과 요약 저장
|
|
summary_path = OUTPUT_DIR / "segmentation_summary.json"
|
|
with open(summary_path, "w", encoding="utf-8") as f:
|
|
json.dump(all_results, f, indent=2, ensure_ascii=False)
|
|
|
|
# 5. 범례 이미지 생성
|
|
legend = np.zeros((len(PROMPTS) * 30 + 20, 300, 3), dtype=np.uint8)
|
|
for i, (prompt, color) in enumerate(COLORS.items()):
|
|
y = i * 30 + 20
|
|
cv2.rectangle(legend, (10, y - 12), (30, y + 5), color, -1)
|
|
cv2.putText(legend, prompt, (40, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
|
cv2.imwrite(str(OUTPUT_DIR / "legend.jpg"), legend)
|
|
|
|
print(f"\n=== 완료 ===")
|
|
print(f"결과 저장: {OUTPUT_DIR}/")
|
|
print(f" - frame_XXXX_original.jpg : 원본 프레임")
|
|
print(f" - frame_XXXX_composite.jpg : 전체 세그멘테이션 합성")
|
|
print(f" - frame_XXXX_layer_*.png : 클래스별 개별 레이어 (투명 배경)")
|
|
print(f" - segmentation_summary.json : 검출 요약")
|
|
print(f" - legend.jpg : 색상 범례")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|