"""Standalone SAM3 "segment everything" demo.

Prompts SAM3 with a regular grid of single points, keeps the best-scoring mask
per point, de-duplicates the masks with greedy mask-IoU NMS and saves a
colour overlay of the resulting polygons.
"""
import argparse
import sys
import os
import time
from pathlib import Path

import cv2
import numpy as np
import torch

# Add server to path so we can import sam3 locally
server_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server"))
models_path = os.path.join(server_path, "app", "models")
if server_path not in sys.path:
    sys.path.insert(0, server_path)
if models_path not in sys.path:
    sys.path.insert(0, models_path)

# These imports must come after the sys.path manipulation above.
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from app.models.segment_anything_3 import SegmentAnything3  # noqa: F401 - unused here; kept so import behavior is unchanged


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generate an ``(n_per_side**2, 2)`` array of (x, y) points in [0, 1] x [0, 1].

    The points are the cell centres of an ``n_per_side`` x ``n_per_side``
    partition of the unit square, so the outermost points sit half a cell in
    from the border. x varies fastest along the first axis.

    Raises:
        ValueError: if ``n_per_side`` is smaller than 1.
    """
    if n_per_side < 1:
        raise ValueError(f"n_per_side must be >= 1, got {n_per_side}")
    offset = 1.0 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    pts_x, pts_y = np.meshgrid(points_one_side, points_one_side)
    return np.stack([pts_x.ravel(), pts_y.ravel()], axis=1)


def mask_iou(mask1, mask2) -> float:
    """Return the intersection-over-union of two masks (nonzero = foreground).

    Both masks must be broadcast-compatible. Returns 0.0 when both are empty.
    """
    m1 = np.asarray(mask1, dtype=bool)
    m2 = np.asarray(mask2, dtype=bool)
    union = np.logical_or(m1, m2).sum()
    if union == 0:
        return 0.0
    inter = np.logical_and(m1, m2).sum()
    return float(inter / union)


def mask_to_polygon(mask, epsilon_factor=0.001):
    """Convert the largest external contour of ``mask`` into a polygon.

    Args:
        mask: mask array; singleton dimensions are squeezed away, and values
            > 0 are treated as foreground.
        epsilon_factor: ``approxPolyDP`` tolerance as a fraction of the contour
            perimeter; ``<= 0`` keeps the raw contour points.

    Returns:
        ``[[x, y], ...]`` as floats, or ``[]`` if the mask has no contour or
        the simplified contour has fewer than three vertices.
    """
    mask_uint8 = (np.squeeze(mask) > 0).astype(np.uint8)
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    largest = max(contours, key=cv2.contourArea)
    if epsilon_factor > 0:
        epsilon = epsilon_factor * cv2.arcLength(largest, True)
        approx = cv2.approxPolyDP(largest, epsilon, True)
    else:
        approx = largest
    # Fewer than three vertices cannot enclose an area (tiny or line-like masks).
    if len(approx) < 3:
        return []
    return [[float(p[0][0]), float(p[0][1])] for p in approx]


def _load_processor(model_path, conf_thresh):
    """Build the SAM3 image model from ``model_path`` and wrap it in a Sam3Processor."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")
    model = build_sam3_image_model(
        bpe_path=bpe_path,
        device=device,
        checkpoint_path=model_path,
    )
    if device == "cuda":
        # Allow TF32 matmul/cuDNN kernels: faster, slightly lower precision.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    return Sam3Processor(model, confidence_threshold=conf_thresh, device=device)


def _mask_nms(masks, scores, iou_thresh):
    """Greedy mask NMS: return indices of kept masks, highest score first.

    A mask is dropped when its IoU with any already-kept mask exceeds ``iou_thresh``.
    """
    keep = []
    for idx in np.argsort(scores)[::-1]:
        if all(mask_iou(masks[idx], masks[k]) <= iou_thresh for k in keep):
            keep.append(int(idx))
    return keep


def segment_everything(image_bgr, model_path, points_per_side=32, conf_thresh=0.8, nms_thresh=0.5):
    """Segment everything in ``image_bgr`` by prompting SAM3 with a point grid.

    Args:
        image_bgr: image as returned by OpenCV (BGR channel order).
        model_path: path to the SAM3 checkpoint.
        points_per_side: grid resolution; ``points_per_side**2`` prompts are issued.
        conf_thresh: a point's best mask is kept only if its score is strictly above this.
        nms_thresh: masks whose IoU with a higher-scoring kept mask exceeds this are dropped.

    Returns:
        A list of ``{"polygon": [[x, y], ...], "score": float}`` ordered by
        descending score; ``[]`` if no mask passes ``conf_thresh``. Polygon
        coordinates are in the resolution of the masks SAM3 returns (the
        caller draws them on ``image_bgr`` directly).

    Raises:
        ValueError: if ``points_per_side`` < 1 (checked before the model is loaded).
    """
    # Validate the grid before paying for the model load.
    grid_points = build_point_grid(points_per_side)
    total = len(grid_points)

    print("Loading SAM3 Model locally...")
    processor = _load_processor(model_path, conf_thresh)

    # Sam3Processor is fed a PIL image; OpenCV images are BGR, PIL expects RGB.
    from PIL import Image
    pil_image = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

    masks = []
    scores = []
    # No gradients are needed; avoid keeping autograd state across many prompts.
    with torch.inference_mode():
        print("Computing image embedding...")
        t0 = time.time()
        state = processor.set_image(pil_image)
        print(f"Image embedding done in {time.time() - t0:.2f}s")

        print(f"Generated {total} grid points for sampling.")
        t0 = time.time()
        for i, (nx, ny) in enumerate(grid_points):
            if i % 100 == 0:
                print(f" Processed {i}/{total} points...")
            processor.reset_all_prompts(state)
            # NOTE(review): the point is passed in normalized [0, 1] coordinates;
            # confirm that is what this sam3 copy's add_point_prompt expects.
            state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
            if "masks" not in state or len(state["masks"]) == 0:
                continue
            # Keep only the highest-scoring mask for this point.
            best_idx = torch.argmax(state["scores"])
            score = state["scores"][best_idx].item()
            if score > conf_thresh:
                masks.append(state["masks"][best_idx].cpu().numpy())
                scores.append(score)
    print(f"Grid prediction done in {time.time() - t0:.2f}s")
    print(f"Found {len(masks)} raw masks.")

    if not masks:
        return []

    print("Applying NMS...")
    keep = _mask_nms(masks, scores, nms_thresh)
    print(f"Kept {len(keep)} masks after NMS.")

    results = []
    for idx in keep:
        poly = mask_to_polygon(masks[idx])
        if poly:
            results.append({"polygon": poly, "score": scores[idx]})
    return results


def main(argv=None):
    """CLI entry point: segment ``--input`` and save a polygon overlay to ``--output``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 1 if the image cannot be read or the
        visualization cannot be encoded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output", required=True, help="Output vis image path")
    parser.add_argument("--points", type=int, default=32, help="Points per side")
    parser.add_argument("--conf", type=float, default=0.7, help="Minimum mask score to keep")
    parser.add_argument("--nms", type=float, default=0.7,
                        help="Mask IoU above which a lower-scoring mask is suppressed")
    parser.add_argument("--max-dim", type=int, default=1024,
                        help="Downscale so the longer image side is at most this many pixels")
    parser.add_argument("--model", default=os.path.join(server_path, "sam3.pt"),
                        help="SAM3 checkpoint path")
    args = parser.parse_args(argv)
    if args.points < 1:
        parser.error("--points must be >= 1")
    if args.max_dim < 1:
        parser.error("--max-dim must be >= 1")

    # Read via np.fromfile + imdecode (instead of cv2.imread) so unicode paths work.
    try:
        buf = np.fromfile(args.input, dtype=np.uint8)
    except OSError as err:
        print(f"Failed to read {args.input}: {err}", file=sys.stderr)
        return 1
    # cv2.imdecode asserts on an empty buffer, so treat an empty file as unreadable.
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image is None:
        print(f"Failed to read {args.input}", file=sys.stderr)
        return 1

    # Shrink image if too large just to make NMS faster
    h, w = image.shape[:2]
    if max(h, w) > args.max_dim:
        scale = args.max_dim / max(h, w)
        # Clamp to >= 1 px so very thin images never produce a zero-sized resize.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_proc = cv2.resize(image, new_size)
    else:
        image_proc = image.copy()

    results = segment_everything(
        image_proc,
        args.model,
        points_per_side=args.points,
        conf_thresh=args.conf,
        nms_thresh=args.nms,
    )

    vis = image_proc.copy()
    # Private seeded RNG: same colour sequence as seeding the global RNG with 42,
    # without mutating NumPy's global random state.
    rng = np.random.RandomState(42)
    for res in results:
        pts = np.array(res["polygon"], dtype=np.int32)
        color = rng.randint(0, 255, (3,)).tolist()
        overlay = vis.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.4, vis, 0.6, 0, vis)
        cv2.polylines(vis, [pts], True, color, 1)

    # Encode in memory and write with ndarray.tofile so unicode output paths work.
    # Pick the codec from the output extension rather than always emitting JPEG.
    ext = Path(args.output).suffix.lower() or ".jpg"
    try:
        is_success, im_buf_arr = cv2.imencode(ext, vis)
    except cv2.error as err:
        print(f"Failed to encode visualization as {ext}: {err}", file=sys.stderr)
        return 1
    if not is_success:
        print(f"Failed to encode visualization as {ext}", file=sys.stderr)
        return 1
    im_buf_arr.tofile(args.output)
    print(f"Saved visualization to {args.output}")
    return 0


if __name__ == "__main__":
    sys.exit(main())