185
tools/sam3_segment_everything.py
Normal file
185
tools/sam3_segment_everything.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Add server to path so we can import sam3 locally
|
||||
server_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server"))
|
||||
models_path = os.path.join(server_path, "app", "models")
|
||||
if server_path not in sys.path:
|
||||
sys.path.insert(0, server_path)
|
||||
if models_path not in sys.path:
|
||||
sys.path.insert(0, models_path)
|
||||
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from app.models.segment_anything_3 import SegmentAnything3
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0, 1] x [0, 1]."""
|
||||
offset = 1.0 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
pts_x, pts_y = np.meshgrid(points_one_side, points_one_side)
|
||||
grid = np.stack([pts_x.flatten(), pts_y.flatten()], axis=1)
|
||||
return grid
|
||||
|
||||
def mask_iou(mask1, mask2):
|
||||
inter = np.logical_and(mask1, mask2).sum()
|
||||
union = np.logical_or(mask1, mask2).sum()
|
||||
if union == 0:
|
||||
return 0
|
||||
return inter / union
|
||||
|
||||
def mask_to_polygon(mask, epsilon_factor=0.001):
|
||||
mask = np.squeeze(mask)
|
||||
mask_uint8 = (mask > 0).astype(np.uint8)
|
||||
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return []
|
||||
largest = max(contours, key=cv2.contourArea)
|
||||
if epsilon_factor > 0:
|
||||
epsilon = epsilon_factor * cv2.arcLength(largest, True)
|
||||
approx = cv2.approxPolyDP(largest, epsilon, True)
|
||||
else:
|
||||
approx = largest
|
||||
points = [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
return points
|
||||
|
||||
def segment_everything(image_bgr, model_path, points_per_side=32, conf_thresh=0.8, nms_thresh=0.5):
|
||||
print("Loading SAM3 Model locally...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
model = build_sam3_image_model(
|
||||
bpe_path=bpe_path,
|
||||
device=device,
|
||||
checkpoint_path=model_path,
|
||||
)
|
||||
|
||||
if device == "cuda":
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
processor = Sam3Processor(model, confidence_threshold=conf_thresh, device=device)
|
||||
|
||||
# PIL image format
|
||||
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||
from PIL import Image
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
|
||||
print("Computing image embedding...")
|
||||
t0 = time.time()
|
||||
state = processor.set_image(pil_image)
|
||||
print(f"Image embedding done in {time.time() - t0:.2f}s")
|
||||
|
||||
grid_points = build_point_grid(points_per_side)
|
||||
print(f"Generated {len(grid_points)} grid points for sampling.")
|
||||
|
||||
masks = []
|
||||
scores = []
|
||||
|
||||
t0 = time.time()
|
||||
for i, (nx, ny) in enumerate(grid_points):
|
||||
if i % 100 == 0:
|
||||
print(f" Processed {i}/{len(grid_points)} points...")
|
||||
|
||||
processor.reset_all_prompts(state)
|
||||
state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
|
||||
|
||||
if "masks" in state and len(state["masks"]) > 0:
|
||||
# Take the mask with the highest score
|
||||
best_idx = torch.argmax(state["scores"])
|
||||
mask = state["masks"][best_idx].cpu().numpy()
|
||||
score = state["scores"][best_idx].item()
|
||||
if score > conf_thresh:
|
||||
masks.append(mask)
|
||||
scores.append(score)
|
||||
print(f"Grid prediction done in {time.time() - t0:.2f}s")
|
||||
print(f"Found {len(masks)} raw masks.")
|
||||
|
||||
if not masks:
|
||||
return []
|
||||
|
||||
# Simple NMS based on IoU
|
||||
print("Applying NMS...")
|
||||
order = np.argsort(scores)[::-1]
|
||||
keep = []
|
||||
|
||||
for idx in order:
|
||||
if len(keep) == 0:
|
||||
keep.append(idx)
|
||||
continue
|
||||
|
||||
current_mask = masks[idx]
|
||||
overlap = False
|
||||
for k in keep:
|
||||
iou = mask_iou(current_mask, masks[k])
|
||||
if iou > nms_thresh:
|
||||
overlap = True
|
||||
break
|
||||
if not overlap:
|
||||
keep.append(idx)
|
||||
|
||||
final_masks = [masks[idx] for idx in keep]
|
||||
final_scores = [scores[idx] for idx in keep]
|
||||
print(f"Kept {len(final_masks)} masks after NMS.")
|
||||
|
||||
# Convert to polygons
|
||||
results = []
|
||||
for m, s in zip(final_masks, final_scores):
|
||||
poly = mask_to_polygon(m)
|
||||
if poly:
|
||||
results.append({"polygon": poly, "score": s})
|
||||
|
||||
return results
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", required=True, help="Input image path")
|
||||
parser.add_argument("--output", required=True, help="Output vis image path")
|
||||
parser.add_argument("--points", type=int, default=32, help="Points per side")
|
||||
args = parser.parse_args()
|
||||
|
||||
buf = np.fromfile(args.input, dtype=np.uint8)
|
||||
image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image is None:
|
||||
print(f"Failed to read {args.input}")
|
||||
return
|
||||
|
||||
# Shrink image if too large just to make NMS faster
|
||||
h, w = image.shape[:2]
|
||||
max_dim = 1024
|
||||
if max(h, w) > max_dim:
|
||||
scale = max_dim / max(h, w)
|
||||
image_proc = cv2.resize(image, (int(w * scale), int(h * scale)))
|
||||
else:
|
||||
image_proc = image.copy()
|
||||
|
||||
model_path = os.path.join(server_path, "sam3.pt")
|
||||
|
||||
results = segment_everything(image_proc, model_path, points_per_side=args.points, conf_thresh=0.7, nms_thresh=0.7)
|
||||
|
||||
vis = image_proc.copy()
|
||||
np.random.seed(42)
|
||||
for res in results:
|
||||
poly = res["polygon"]
|
||||
pts = np.array(poly, dtype=np.int32)
|
||||
color = np.random.randint(0, 255, (3,)).tolist()
|
||||
|
||||
overlay = vis.copy()
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.addWeighted(overlay, 0.4, vis, 0.6, 0, vis)
|
||||
cv2.polylines(vis, [pts], True, color, 1)
|
||||
|
||||
# Fix unicode paths in output
|
||||
is_success, im_buf_arr = cv2.imencode(".jpg", vis)
|
||||
if is_success:
|
||||
im_buf_arr.tofile(args.output)
|
||||
print(f"Saved visualization to {args.output}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user