237
tools/video_sam3_segment.py
Normal file
237
tools/video_sam3_segment.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
드론 영상에서 프레임 추출 → SAM3 서버로 모든 객체 세그멘테이션
|
||||
Usage:
|
||||
1. SAM3 서버 시작: start_server.bat
|
||||
2. python tools/video_sam3_segment.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import requests
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# === 설정 ===
|
||||
VIDEO_PATH = Path("sample/rail.mp4")
|
||||
OUTPUT_DIR = Path("output/video_segmentation")
|
||||
SAM3_URL = "http://localhost:8000"
|
||||
MODEL_ID = "segment_anything_3"
|
||||
FRAME_INTERVAL = 30 # 30fps 영상에서 1초 간격
|
||||
|
||||
# 철도 시설물 + 일반 객체 프롬프트
|
||||
PROMPTS = [
|
||||
"catenary pole", # 전철주
|
||||
"junction box", # 전기박스
|
||||
"utility box", # 통신박스
|
||||
"rail track", # 레일
|
||||
"fence", # 펜스
|
||||
"cable", # 전선/케이블
|
||||
"sign", # 표지판
|
||||
"building", # 건물
|
||||
"vegetation", # 식생
|
||||
]
|
||||
|
||||
# 클래스별 색상 (BGR)
|
||||
COLORS = {
|
||||
"catenary pole": (0, 0, 255), # 빨강
|
||||
"junction box": (0, 165, 255), # 주황
|
||||
"utility box": (0, 255, 255), # 노랑
|
||||
"rail track": (255, 0, 0), # 파랑
|
||||
"fence": (255, 0, 255), # 자홍
|
||||
"cable": (0, 255, 0), # 초록
|
||||
"sign": (255, 255, 0), # 시안
|
||||
"building": (128, 128, 255), # 연한 빨강
|
||||
"vegetation": (0, 128, 0), # 진한 초록
|
||||
}
|
||||
|
||||
|
||||
def extract_frames(video_path: Path, interval: int) -> list[tuple[int, np.ndarray]]:
|
||||
"""동영상에서 일정 간격으로 프레임 추출"""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
print(f"ERROR: Cannot open video: {video_path}")
|
||||
sys.exit(1)
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
print(f"Video: {video_path.name} | {total} frames @ {fps:.0f}fps | {total/fps:.1f}s")
|
||||
|
||||
frames = []
|
||||
idx = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
if idx % interval == 0:
|
||||
frames.append((idx, frame))
|
||||
idx += 1
|
||||
|
||||
cap.release()
|
||||
print(f"Extracted {len(frames)} frames (interval={interval})")
|
||||
return frames
|
||||
|
||||
|
||||
def encode_frame(frame: np.ndarray) -> str:
|
||||
"""프레임을 base64로 인코딩"""
|
||||
_, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
def predict_sam3(image_b64: str, text_prompt: str) -> dict:
|
||||
"""SAM3 서버에 예측 요청"""
|
||||
payload = {
|
||||
"model": MODEL_ID,
|
||||
"image": image_b64,
|
||||
"params": {
|
||||
"text_prompt": text_prompt,
|
||||
"conf_threshold": 0.25,
|
||||
},
|
||||
}
|
||||
try:
|
||||
resp = requests.post(f"{SAM3_URL}/v1/predict", json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
print(f"ERROR: SAM3 서버에 연결할 수 없습니다. start_server.bat을 먼저 실행하세요.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f" ERROR [{text_prompt}]: {e}")
|
||||
return {"success": False}
|
||||
|
||||
|
||||
def draw_shapes_on_frame(frame: np.ndarray, shapes: list, prompt: str) -> np.ndarray:
|
||||
"""세그멘테이션 결과를 프레임 위에 그리기"""
|
||||
overlay = frame.copy()
|
||||
color = COLORS.get(prompt, (200, 200, 200))
|
||||
|
||||
for shape in shapes:
|
||||
points = np.array(shape["points"], dtype=np.int32)
|
||||
shape_type = shape.get("shape_type", "polygon")
|
||||
|
||||
if shape_type == "polygon" and len(points) >= 3:
|
||||
cv2.fillPoly(overlay, [points], color)
|
||||
cv2.polylines(frame, [points], True, color, 2)
|
||||
elif shape_type == "rectangle" and len(points) == 2:
|
||||
cv2.rectangle(overlay, tuple(points[0]), tuple(points[1]), color, -1)
|
||||
cv2.rectangle(frame, tuple(points[0]), tuple(points[1]), color, 2)
|
||||
|
||||
# 라벨 텍스트
|
||||
label = shape.get("label", prompt)
|
||||
score = shape.get("score")
|
||||
text = f"{label}" + (f" {score:.2f}" if score else "")
|
||||
if len(points) > 0:
|
||||
tx, ty = int(points[0][0]), int(points[0][1]) - 5
|
||||
cv2.putText(frame, text, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||
|
||||
# 반투명 오버레이 블렌딩
|
||||
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
||||
return frame
|
||||
|
||||
|
||||
def create_layer_image(frame_shape: tuple, shapes: list, prompt: str) -> np.ndarray:
|
||||
"""클래스별 개별 레이어 이미지 (투명 배경 위 마스크)"""
|
||||
h, w = frame_shape[:2]
|
||||
color = COLORS.get(prompt, (200, 200, 200))
|
||||
bgr = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
alpha = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
for shape in shapes:
|
||||
points = np.array(shape["points"], dtype=np.int32)
|
||||
if shape.get("shape_type") == "polygon" and len(points) >= 3:
|
||||
cv2.fillPoly(bgr, [points], color)
|
||||
cv2.fillPoly(alpha, [points], 180)
|
||||
|
||||
layer = np.dstack([bgr, alpha])
|
||||
return layer
|
||||
|
||||
|
||||
def main():
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. 서버 상태 확인
|
||||
print("=== SAM3 서버 연결 확인 ===")
|
||||
try:
|
||||
health = requests.get(f"{SAM3_URL}/health", timeout=5)
|
||||
print(f"Server status: {health.json()}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("ERROR: SAM3 서버가 실행 중이 아닙니다!")
|
||||
print(" → start_server.bat을 먼저 실행하세요.")
|
||||
sys.exit(1)
|
||||
|
||||
# 2. 프레임 추출
|
||||
print("\n=== 프레임 추출 ===")
|
||||
frames = extract_frames(VIDEO_PATH, FRAME_INTERVAL)
|
||||
|
||||
# 3. 각 프레임별, 각 프롬프트별 세그멘테이션
|
||||
all_results = {}
|
||||
|
||||
for frame_idx, (fidx, frame) in enumerate(frames):
|
||||
print(f"\n=== Frame {fidx} ({frame_idx+1}/{len(frames)}) ===")
|
||||
|
||||
frame_b64 = encode_frame(frame)
|
||||
frame_results = {}
|
||||
composite = frame.copy()
|
||||
|
||||
for prompt in PROMPTS:
|
||||
print(f" Segmenting: {prompt}...", end=" ")
|
||||
result = predict_sam3(frame_b64, prompt)
|
||||
|
||||
if result.get("success") and result.get("data", {}).get("shapes"):
|
||||
shapes = result["data"]["shapes"]
|
||||
n = len(shapes)
|
||||
print(f"→ {n} objects found")
|
||||
frame_results[prompt] = shapes
|
||||
|
||||
# 합성 이미지에 그리기
|
||||
composite = draw_shapes_on_frame(composite, shapes, prompt)
|
||||
|
||||
# 개별 레이어 저장 (PNG with alpha)
|
||||
layer = create_layer_image(frame.shape, shapes, prompt)
|
||||
layer_name = prompt.replace(" ", "_")
|
||||
layer_path = OUTPUT_DIR / f"frame_{fidx:04d}_layer_{layer_name}.png"
|
||||
cv2.imwrite(str(layer_path), layer)
|
||||
else:
|
||||
print("→ no objects")
|
||||
|
||||
# 합성 이미지 저장
|
||||
composite_path = OUTPUT_DIR / f"frame_{fidx:04d}_composite.jpg"
|
||||
cv2.imwrite(str(composite_path), composite)
|
||||
|
||||
# 원본 프레임 저장 (비교용)
|
||||
original_path = OUTPUT_DIR / f"frame_{fidx:04d}_original.jpg"
|
||||
cv2.imwrite(str(original_path), frame)
|
||||
|
||||
all_results[f"frame_{fidx}"] = {
|
||||
"frame_index": fidx,
|
||||
"detections": {
|
||||
prompt: len(shapes)
|
||||
for prompt, shapes in frame_results.items()
|
||||
},
|
||||
}
|
||||
|
||||
# 4. 결과 요약 저장
|
||||
summary_path = OUTPUT_DIR / "segmentation_summary.json"
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 5. 범례 이미지 생성
|
||||
legend = np.zeros((len(PROMPTS) * 30 + 20, 300, 3), dtype=np.uint8)
|
||||
for i, (prompt, color) in enumerate(COLORS.items()):
|
||||
y = i * 30 + 20
|
||||
cv2.rectangle(legend, (10, y - 12), (30, y + 5), color, -1)
|
||||
cv2.putText(legend, prompt, (40, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
cv2.imwrite(str(OUTPUT_DIR / "legend.jpg"), legend)
|
||||
|
||||
print(f"\n=== 완료 ===")
|
||||
print(f"결과 저장: {OUTPUT_DIR}/")
|
||||
print(f" - frame_XXXX_original.jpg : 원본 프레임")
|
||||
print(f" - frame_XXXX_composite.jpg : 전체 세그멘테이션 합성")
|
||||
print(f" - frame_XXXX_layer_*.png : 클래스별 개별 레이어 (투명 배경)")
|
||||
print(f" - segmentation_summary.json : 검출 요약")
|
||||
print(f" - legend.jpg : 색상 범례")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user