186 lines
6.1 KiB
Python
186 lines
6.1 KiB
Python
import argparse
|
|
import sys
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Put the sibling X-AnyLabeling-Server checkout (and its app/models folder)
# on sys.path so the sam3 / app.models imports below resolve locally.
server_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")
)
models_path = os.path.join(server_path, "app", "models")

# Each entry is pushed to the front, so models_path (inserted last) ends up
# ahead of server_path in the search order.
for _candidate in (server_path, models_path):
    if _candidate not in sys.path:
        sys.path.insert(0, _candidate)
del _candidate
|
|
|
|
from sam3.model_builder import build_sam3_image_model
|
|
from sam3.model.sam3_image_processor import Sam3Processor
|
|
from app.models.segment_anything_3 import SegmentAnything3
|
|
|
|
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an ``(n_per_side ** 2, 2)`` array of points covering the unit square.

    Along each axis the samples are the centres of ``n_per_side`` equal-width
    cells spanning [0, 1], so the outermost samples sit half a cell in from
    the border. Each row is an ``(x, y)`` pair; x varies fastest.
    """
    half_cell = 1.0 / (2 * n_per_side)
    axis_coords = np.linspace(half_cell, 1 - half_cell, n_per_side)
    xs, ys = np.meshgrid(axis_coords, axis_coords)
    return np.column_stack((xs.ravel(), ys.ravel()))
|
|
|
|
def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two masks as a Python float.

    Both arguments are interpreted element-wise as booleans (any non-zero
    value counts as foreground) and must be broadcast-compatible.

    Returns:
        ``|mask1 AND mask2| / |mask1 OR mask2|`` in [0.0, 1.0], or ``0.0``
        when both masks are empty (the union has no foreground).
    """
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        # Two empty masks: defined as no overlap. Returned as a float so the
        # result type does not depend on the input.
        return 0.0
    return float(intersection) / float(union)
|
|
|
|
def mask_to_polygon(mask, epsilon_factor=0.001):
    """Convert a mask into the polygon outlining its largest external contour.

    Args:
        mask: Array-like mask; singleton dimensions are squeezed away and
            every value greater than 0 is treated as foreground.
        epsilon_factor: Simplification tolerance expressed as a fraction of
            the contour perimeter. Values <= 0 disable simplification.

    Returns:
        A list of ``[x, y]`` float pairs, or an empty list when the mask
        contains no contour.
    """
    binary = (np.squeeze(mask) > 0).astype(np.uint8)
    contours, _hierarchy = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return []

    # Only the contour with the greatest area is kept; other components of
    # the mask are dropped.
    outline = max(contours, key=cv2.contourArea)
    if epsilon_factor > 0:
        tolerance = epsilon_factor * cv2.arcLength(outline, True)
        outline = cv2.approxPolyDP(outline, tolerance, True)

    # OpenCV contours are shaped (N, 1, 2); flatten to N (x, y) pairs.
    return [[float(x), float(y)] for x, y in outline.reshape(-1, 2)]
|
|
|
|
def segment_everything(image_bgr, model_path, points_per_side=32, conf_thresh=0.8, nms_thresh=0.5):
    """Segment an image by prompting SAM3 once per point of a regular grid.

    The model is built from ``model_path``, the image is embedded once, and
    every grid point from ``build_point_grid`` is submitted as a single point
    prompt. For each prompt the highest-scoring mask is kept if its score
    exceeds ``conf_thresh``. Surviving masks are de-duplicated with a greedy
    IoU-based NMS and converted to polygons.

    Args:
        image_bgr: Image array in BGR channel order (converted to RGB here).
        model_path: Path to the SAM3 checkpoint.
        points_per_side: Number of grid points along each axis.
        conf_thresh: Passed to the processor as its confidence threshold and
            also used as the minimum score for a mask to be kept.
        nms_thresh: A mask is suppressed when its IoU with an already kept,
            higher-scoring mask exceeds this value.

    Returns:
        A list of ``{"polygon": [[x, y], ...], "score": float}`` dicts, empty
        when no mask passes the score threshold.
        NOTE(review): polygon coordinates are in the pixel grid of the masks
        returned by the processor - presumably the input image size; confirm.
    """
    print("Loading SAM3 Model locally...")
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model = build_sam3_image_model(
        bpe_path=os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz"),
        device=device,
        checkpoint_path=model_path,
    )

    if device == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    processor = Sam3Processor(model, confidence_threshold=conf_thresh, device=device)

    # The processor is handed a PIL image in RGB order.
    rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    from PIL import Image
    pil_image = Image.fromarray(rgb)

    print("Computing image embedding...")
    embed_start = time.time()
    state = processor.set_image(pil_image)
    print(f"Image embedding done in {time.time() - embed_start:.2f}s")

    grid_points = build_point_grid(points_per_side)
    n_points = len(grid_points)
    print(f"Generated {n_points} grid points for sampling.")

    masks = []
    scores = []

    grid_start = time.time()
    for point_idx, (nx, ny) in enumerate(grid_points):
        if point_idx % 100 == 0:
            print(f" Processed {point_idx}/{n_points} points...")

        # Each grid point is evaluated in isolation: clear earlier prompts,
        # then submit this point (coordinates in [0, 1]) with label=True.
        processor.reset_all_prompts(state)
        state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)

        if "masks" not in state or len(state["masks"]) == 0:
            continue

        # Only the single best-scoring mask for this prompt is considered.
        best = torch.argmax(state["scores"])
        candidate = state["masks"][best].cpu().numpy()
        candidate_score = state["scores"][best].item()
        if candidate_score > conf_thresh:
            masks.append(candidate)
            scores.append(candidate_score)
    print(f"Grid prediction done in {time.time() - grid_start:.2f}s")
    print(f"Found {len(masks)} raw masks.")

    if not masks:
        return []

    # Greedy NMS: visit masks from highest to lowest score and keep one only
    # if it does not overlap any already kept mask by more than nms_thresh.
    print("Applying NMS...")
    ranked = np.argsort(scores)[::-1]
    kept = []
    for cand in ranked:
        cand_mask = masks[cand]
        if any(mask_iou(cand_mask, masks[k]) > nms_thresh for k in kept):
            continue
        kept.append(cand)

    final_masks = [masks[k] for k in kept]
    final_scores = [scores[k] for k in kept]
    print(f"Kept {len(final_masks)} masks after NMS.")

    # Masks that yield no contour are dropped from the result.
    outlined = ((mask_to_polygon(m), s) for m, s in zip(final_masks, final_scores))
    return [{"polygon": poly, "score": s} for poly, s in outlined if poly]
|
|
|
|
def main():
    """Command-line entry point: segment an image and save a visualization.

    Reads ``--input``, downsizes it so its longest side is at most 1024 px,
    runs ``segment_everything`` with the checkpoint ``sam3.pt`` located under
    ``server_path``, paints every returned polygon semi-transparently in a
    pseudo-random colour (fixed seed, so colours are reproducible) and writes
    the result to ``--output``.

    The output is always JPEG-encoded, regardless of the extension of
    ``--output``.

    Failures to read the input or encode the output are reported on stdout
    and end the function early without raising.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output", required=True, help="Output vis image path")
    parser.add_argument("--points", type=int, default=32, help="Points per side")
    parser.add_argument("--conf-thresh", type=float, default=0.7,
                        help="Minimum mask score to keep")
    parser.add_argument("--nms-thresh", type=float, default=0.7,
                        help="IoU above which a lower-scoring mask is suppressed")
    args = parser.parse_args()

    # Reject an unusable grid size up front instead of failing after the
    # (expensive) model load.
    if args.points < 1:
        parser.error("--points must be >= 1")

    # Read the raw bytes and decode in memory, mirroring the unicode-path
    # handling used for the output below. A missing/unreadable or empty file
    # is routed to the same "Failed to read" message as an undecodable one.
    try:
        buf = np.fromfile(args.input, dtype=np.uint8)
    except OSError:
        buf = None
    image = None
    if buf is not None and buf.size > 0:
        image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if image is None:
        print(f"Failed to read {args.input}")
        return

    # Shrink image if too large just to make NMS faster
    h, w = image.shape[:2]
    max_dim = 1024
    longest = max(h, w)
    if longest > max_dim:
        scale = max_dim / longest
        # Clamp to 1 px so extreme aspect ratios never yield a zero-sized axis.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_proc = cv2.resize(image, new_size)
    else:
        image_proc = image

    model_path = os.path.join(server_path, "sam3.pt")

    results = segment_everything(
        image_proc,
        model_path,
        points_per_side=args.points,
        conf_thresh=args.conf_thresh,
        nms_thresh=args.nms_thresh,
    )

    vis = image_proc.copy()
    np.random.seed(42)
    for res in results:
        pts = np.array(res["polygon"], dtype=np.int32)
        color = np.random.randint(0, 255, (3,)).tolist()

        # Blend a filled copy at 40% opacity, then draw a crisp 1 px outline.
        overlay = vis.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.4, vis, 0.6, 0, vis)
        cv2.polylines(vis, [pts], True, color, 1)

    # Fix unicode paths in output
    is_success, im_buf_arr = cv2.imencode(".jpg", vis)
    if not is_success:
        print(f"Failed to encode visualization for {args.output}")
        return
    im_buf_arr.tofile(args.output)
    print(f"Saved visualization to {args.output}")


if __name__ == "__main__":
    main()
|