Files
railway-client/tools/sam3_segment_everything.py
minsung ccba1266b5 프로젝트 분리 이동
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 14:28:27 +09:00

186 lines
6.1 KiB
Python

import argparse
import sys
import os
import time
from pathlib import Path
import cv2
import numpy as np
import torch
# Make the sibling X-AnyLabeling-Server checkout importable so the `sam3`
# and `app.models` packages used below resolve without installation.
server_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")
)
models_path = os.path.join(server_path, "app", "models")
# Each missing entry is pushed to the front of sys.path; because models_path
# is handled last it ends up ahead of server_path.
for _extra_path in (server_path, models_path):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from app.models.segment_anything_3 import SegmentAnything3
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an ``(n_per_side ** 2, 2)`` array of (x, y) points in the unit square.

    Points sit at the centres of an ``n_per_side`` x ``n_per_side`` cell grid,
    so the first and last coordinates are half a cell away from 0 and 1.
    Rows are ordered with x varying fastest.
    """
    half_cell = 1.0 / (2 * n_per_side)
    axis_coords = np.linspace(half_cell, 1 - half_cell, n_per_side)
    xs, ys = np.meshgrid(axis_coords, axis_coords)
    return np.column_stack((xs.ravel(), ys.ravel()))
def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two masks.

    Both inputs are interpreted element-wise as booleans (non-zero is true).
    When neither mask has any true element the result is ``0``.
    """
    union_area = np.logical_or(mask1, mask2).sum()
    if union_area == 0:
        # Both masks are empty; avoid dividing by zero.
        return 0
    overlap_area = np.logical_and(mask1, mask2).sum()
    return overlap_area / union_area
def mask_to_polygon(mask, epsilon_factor=0.001):
    """Convert a mask into the polygon outlining its largest external contour.

    Args:
        mask: Array-like mask; singleton dimensions are squeezed away and any
            value greater than zero counts as foreground.
        epsilon_factor: Simplification tolerance for ``cv2.approxPolyDP``,
            expressed as a fraction of the contour perimeter. A value of zero
            or less returns the raw contour unsimplified.

    Returns:
        A list of ``[x, y]`` float pairs, or an empty list when the mask
        contains no contour.
    """
    binary = (np.squeeze(mask) > 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []

    # Only the contour with the greatest area is kept; smaller ones are ignored.
    outline = max(contours, key=cv2.contourArea)
    if epsilon_factor > 0:
        tolerance = epsilon_factor * cv2.arcLength(outline, True)
        outline = cv2.approxPolyDP(outline, tolerance, True)

    # Contour entries are nested one level deep: entry[0] holds the (x, y) pair.
    return [[float(entry[0][0]), float(entry[0][1])] for entry in outline]
def segment_everything(image_bgr, model_path, points_per_side=32, conf_thresh=0.8, nms_thresh=0.5):
    """Segment an image by prompting SAM3 with a regular grid of positive points.

    The model is built from ``model_path``, the image is embedded once, and
    one single-point prompt is issued per grid location. For every prompt the
    highest-scoring mask is retained when its score exceeds ``conf_thresh``.
    Retained masks are de-duplicated with a greedy IoU-based NMS and then
    converted to polygons.

    Args:
        image_bgr: Image array in BGR channel order (converted to RGB here).
        model_path: Checkpoint path handed to ``build_sam3_image_model``.
        points_per_side: Grid resolution; ``points_per_side ** 2`` prompts
            are issued.
        conf_thresh: Given to ``Sam3Processor`` as its confidence threshold
            and also used as the strict lower bound on a kept mask's score.
        nms_thresh: A mask is discarded when its IoU with an already kept,
            higher-ranked mask exceeds this value.

    Returns:
        A list of ``{"polygon": [[x, y], ...], "score": float}`` dicts in NMS
        order (best score first), or an empty list when no mask qualifies.
    """
    print("Loading SAM3 Model locally...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_sam3_image_model(
        bpe_path=os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz"),
        device=device,
        checkpoint_path=model_path,
    )
    if device == "cuda":
        # Permit TF32 kernels for matmul and cuDNN when running on GPU.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    processor = Sam3Processor(model, confidence_threshold=conf_thresh, device=device)

    # The processor is given a PIL image in RGB order.
    rgb_pixels = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    from PIL import Image
    pil_image = Image.fromarray(rgb_pixels)

    print("Computing image embedding...")
    t0 = time.time()
    state = processor.set_image(pil_image)
    print(f"Image embedding done in {time.time() - t0:.2f}s")

    grid_points = build_point_grid(points_per_side)
    print(f"Generated {len(grid_points)} grid points for sampling.")

    masks = []
    scores = []
    t0 = time.time()
    for i, (nx, ny) in enumerate(grid_points):
        if i % 100 == 0:
            print(f" Processed {i}/{len(grid_points)} points...")

        # Every grid point is an independent prompt: clear the previous one,
        # then add a single positive point at the grid coordinates.
        processor.reset_all_prompts(state)
        state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
        if "masks" not in state or len(state["masks"]) == 0:
            continue

        # Of the candidates returned for this point, keep only the best one.
        top = torch.argmax(state["scores"])
        top_mask = state["masks"][top].cpu().numpy()
        top_score = state["scores"][top].item()
        if top_score > conf_thresh:
            masks.append(top_mask)
            scores.append(top_score)

    print(f"Grid prediction done in {time.time() - t0:.2f}s")
    print(f"Found {len(masks)} raw masks.")
    if not masks:
        return []

    # Greedy NMS: visit masks from highest to lowest score and keep one only
    # if it does not overlap any previously kept mask by more than nms_thresh.
    print("Applying NMS...")
    keep = []
    for idx in np.argsort(scores)[::-1]:
        candidate = masks[idx]
        if any(mask_iou(candidate, masks[k]) > nms_thresh for k in keep):
            continue
        keep.append(idx)

    final_masks = [masks[idx] for idx in keep]
    final_scores = [scores[idx] for idx in keep]
    print(f"Kept {len(final_masks)} masks after NMS.")

    # Masks that produce no contour are dropped from the output.
    results = []
    for kept_mask, kept_score in zip(final_masks, final_scores):
        polygon = mask_to_polygon(kept_mask)
        if polygon:
            results.append({"polygon": polygon, "score": kept_score})
    return results
def main():
    """CLI entry point: segment an image with SAM3 and save a colored overlay.

    Command-line arguments:
        --input:  path of the image to segment (required).
        --output: path of the visualization image to write (required).
        --points: grid points per side used for prompting (default 32).

    The image is downscaled so that its longer side is at most 1024 px, passed
    to ``segment_everything``, and every returned polygon is blended onto the
    image in a pseudo-random color. Failures to read the input or encode the
    output are reported on stdout and the function returns without raising.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output", required=True, help="Output vis image path")
    parser.add_argument("--points", type=int, default=32, help="Points per side")
    args = parser.parse_args()

    # Read the raw bytes with NumPy and decode in memory (mirrors the
    # imencode/tofile pair used for the output to cope with unicode paths).
    try:
        buf = np.fromfile(args.input, dtype=np.uint8)
    except OSError as exc:
        # Missing or unreadable file: report it like a decode failure instead
        # of crashing with a traceback.
        print(f"Failed to read {args.input}: {exc}")
        return
    # cv2.imdecode rejects an empty buffer, so treat an empty file as unreadable.
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image is None:
        print(f"Failed to read {args.input}")
        return

    # Shrink image if too large just to make NMS faster
    h, w = image.shape[:2]
    max_dim = 1024
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        # Clamp to 1 px so extremely elongated images never get a zero-sized side.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_proc = cv2.resize(image, new_size)
    else:
        image_proc = image.copy()

    model_path = os.path.join(server_path, "sam3.pt")
    results = segment_everything(image_proc, model_path, points_per_side=args.points, conf_thresh=0.7, nms_thresh=0.7)

    vis = image_proc.copy()
    # Fixed seed so repeated runs color the polygons identically.
    np.random.seed(42)
    for res in results:
        pts = np.array(res["polygon"], dtype=np.int32)
        color = np.random.randint(0, 255, (3,)).tolist()
        # Blend a filled copy at 40% opacity, then trace the outline on top.
        overlay = vis.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.4, vis, 0.6, 0, vis)
        cv2.polylines(vis, [pts], True, color, 1)

    # Encode with the format implied by the output extension so the file
    # content matches its name; anything unrecognised falls back to JPEG.
    ext = os.path.splitext(args.output)[1].lower()
    if ext not in (".jpg", ".jpeg", ".png", ".bmp"):
        ext = ".jpg"

    # Fix unicode paths in output
    is_success, im_buf_arr = cv2.imencode(ext, vis)
    if is_success:
        im_buf_arr.tofile(args.output)
        print(f"Saved visualization to {args.output}")
    else:
        print(f"Failed to encode visualization for {args.output}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()