initial_tune

AI-team\cyhan
2025-05-12 11:23:49 +09:00
commit b436a81091
33 changed files with 2398 additions and 0 deletions

View File

@@ -0,0 +1,218 @@
import enum
import json
import os
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox
from segment_anything_ui.draw_label import PaintType
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation
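# Merging is a two-step flow driven by on_merge_masks below: the first press of
# "Merge Masks" stores the currently picked mask (PICKING -> MERGING); the
# second press ORs it with the next picked mask and returns to PICKING.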
class MergeState(enum.Enum):
PICKING = enum.auto()
MERGING = enum.auto()
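# Left-hand control panel: one button per annotation action, each captioned
# with its keyboard shortcut from config.key_mapping.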
class AnnotationLayout(QWidget):
def __init__(self, parent, config) -> None:
super().__init__(parent)
self.config = config
self.zoom_flag = False
self.merge_state = MergeState.PICKING
self.layout = QVBoxLayout(self)
labels = self._load_labels(config)
self.layout.setAlignment(Qt.AlignTop)
self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]")
self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]")
self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]")
self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]")
self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]")
self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]")
self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]")
self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]")
self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]")
self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]")
self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]")
self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]")
self.label_picker = QListWidget()
self.label_picker.addItems(labels)
self.label_picker.setCurrentRow(0)
        self.move_current_mask_background = QPushButton("Move Current Mask to Background")
self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than")
self.remove_hidden_masks_line = QLineEdit("0.5")
self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key)
self.add_point.setShortcut(config.key_mapping.ADD_POINT.key)
self.add_box.setShortcut(config.key_mapping.ADD_BOX.key)
self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key)
self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key)
self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key)
self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key)
self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key)
self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key)
        self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key)
        # The button captions advertise these shortcuts, so register them as well.
        self.manual_polygon.setShortcut(config.key_mapping.MANUAL_POLYGON.key)
        self.merge_masks.setShortcut(config.key_mapping.MERGE_MASK.key)
self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())
for w in [
self.add_point,
self.add_box,
self.annotate_all,
self.pick_mask,
self.pick_bounding_box,
self.merge_masks,
self.move_current_mask_background,
self.cancel_annotation,
self.delete_mask,
self.partial_annotation,
self.save_annotation,
self.manual_polygon,
self.label_picker,
self.zoom_rectangle,
self.annotation_settings,
self.remove_hidden_masks,
self.remove_hidden_masks_label,
self.remove_hidden_masks_line
]:
self.layout.addWidget(w)
self.add_point.clicked.connect(self.on_add_point)
self.add_box.clicked.connect(self.on_add_box)
self.annotate_all.clicked.connect(self.on_annotate_all)
self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
self.save_annotation.clicked.connect(self.on_save_annotation)
self.pick_mask.clicked.connect(self.on_pick_mask)
self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
self.manual_polygon.clicked.connect(self.on_manual_polygon)
self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
self.merge_masks.clicked.connect(self.on_merge_masks)
self.partial_annotation.clicked.connect(self.on_partial_annotation)
self.delete_mask.clicked.connect(self.on_delete_mask)
self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)
def on_delete_mask(self):
if self.parent().image_label.paint_type == PaintType.MASK_PICKER:
self.parent().info_label.setText("Deleting mask!")
mask_uid = self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
self.parent().annotator.bounding_boxes.remove(mask_uid)
self.parent().annotator.masks.mask_id = -1
self.parent().annotator.last_mask = None
self.parent().update(self.parent().annotator.merge_image_visualization())
elif self.parent().image_label.paint_type == PaintType.BOX_PICKER:
self.parent().info_label.setText("Deleting bounding box!")
mask_uid = self.parent().annotator.bounding_boxes.remove_by_id(
self.parent().annotator.bounding_boxes.bounding_box_id)
if mask_uid is not None:
self.parent().annotator.masks.pop_by_uuid(mask_uid)
self.parent().annotator.bounding_boxes.bounding_box_id = -1
self.parent().annotator.last_mask = None
self.parent().annotator.masks.mask_id = -1
self.parent().update(self.parent().annotator.merge_image_visualization())
else:
QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")
def on_partial_annotation(self):
self.parent().info_label.setText("Partial annotation!")
self.parent().annotator.pick_partial_mask()
self.parent().image_label.clear()
@staticmethod
def _load_labels(config):
if not os.path.exists(config.label_file):
return ["default"]
with open(config.label_file, "r") as f:
labels = json.load(f)
MasksAnnotation.DEFAULT_LABEL = list(labels.keys())[0] if len(labels) > 0 else "default"
        # labels.json maps label names to values; the picker only needs the names.
        return list(labels.keys())
def on_merge_masks(self):
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
if self.merge_state == MergeState.PICKING:
self.parent().info_label.setText("Pick a mask to merge with!")
self.merge_state = MergeState.MERGING
self.parent().annotator.merged_mask = self.parent().annotator.last_mask.copy()
elif self.merge_state == MergeState.MERGING:
self.parent().info_label.setText("Merging masks!")
self.parent().annotator.merge_masks()
self.merge_state = MergeState.PICKING
def on_move_current_mask_background_fn(self):
self.parent().info_label.setText("Moving current mask to background!")
self.parent().annotator.move_current_mask_to_background()
self.parent().update(self.parent().annotator.merge_image_visualization())
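    # A mask counts as "hidden" when the fraction of its pixels that survives
    # the argmax over all masks (i.e. is actually visible) is at or below the
    # threshold from the line edit; such masks are dropped below.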
def on_remove_hidden_masks(self):
self.parent().info_label.setText("Removing hidden masks!")
annotations = self.parent().annotator.masks
argmax_mask = self.parent().annotator.make_instance_mask()
limit_ratio = float(self.remove_hidden_masks_line.text())
new_masks = []
new_labels = []
for idx, (mask, label) in enumerate(annotations):
num_pixels = np.sum(mask > 0)
num_visible = np.sum(argmax_mask == (idx + 1))
ratio = num_visible / num_pixels
if ratio > limit_ratio:
new_masks.append(mask)
new_labels.append(label)
print("Removed ", len(annotations) - len(new_masks), " masks.")
self.parent().annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_pick_mask(self):
self.parent().info_label.setText("Pick a mask to do modifications!")
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
def on_pick_bounding_box(self):
self.parent().info_label.setText("Pick a bounding box to do modifications!")
self.parent().image_label.change_paint_type(PaintType.BOX_PICKER)
def on_manual_polygon(self):
# Sets emphasis on the button
self.manual_polygon.setProperty("active", True)
self.parent().image_label.change_paint_type(PaintType.POLYGON)
def on_add_point(self):
self.parent().info_label.setText("Adding point annotation!")
self.parent().image_label.change_paint_type(PaintType.POINT)
def on_add_box(self):
self.parent().info_label.setText("Adding box annotation!")
self.parent().image_label.change_paint_type(PaintType.BOX)
def on_zoom_rectangle(self):
if self.zoom_flag:
self.parent().info_label.setText("Zooming rectangle OFF!")
self.parent().image_label.change_paint_type(PaintType.POINT)
self.parent().annotator.zoomed_bounding_box = None
self.parent().annotator.make_embedding()
self.parent().update(self.parent().annotator.merge_image_visualization())
self.zoom_flag = False
else:
self.parent().info_label.setText("Pick Mask to zoom!")
self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
self.parent().image_label.change_paint_type(PaintType.ZOOM_PICKER)
self.zoom_flag = True
def on_annotate_all(self):
self.parent().info_label.setText("Annotating all. This make take some time.")
self.parent().annotator.predict_all(self.annotation_settings.get_values())
self.parent().update(self.parent().annotator.merge_image_visualization())
self.parent().info_label.setText("Annotate all finished.")
def on_cancel_annotation(self):
self.parent().image_label.clear()
self.parent().annotator.clear_last_masks()
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_save_annotation(self):
if self.parent().image_label.paint_type == PaintType.POLYGON:
self.parent().annotator.last_mask = self.parent().image_label.polygon.to_mask(
self.config.window_size[0], self.config.window_size[1])
self.parent().annotator.save_mask(label=self.label_picker.currentItem().text())
self.parent().update(self.parent().annotator.merge_image_visualization())
self.parent().image_label.clear()

View File

@@ -0,0 +1,445 @@
import dataclasses
from typing import Callable
import uuid
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit
from segment_anything import SamPredictor
from segment_anything.build_sam import Sam
from segment_anything_ui.model_builder import get_predictor, get_mask_generator
try:
from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam
except (ImportError, ModuleNotFoundError):
class EfficientViTSamPredictor:
pass
class EfficientViTSam:
pass
from skimage.measure import regionprops
import torch
from segment_anything_ui.utils.shapes import BoundingBox
from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box
def get_cmap(n, name='hsv'):
'''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
RGB color; the keyword argument name must be a standard mpl colormap name.'''
try:
return plt.cm.get_cmap(name, n)
    except AttributeError:
        # plt.cm.get_cmap was removed in newer matplotlib releases; fall back
        # to the top-level accessor.
        return plt.get_cmap(name, n)
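# crop_image / insert_image are the zoom helpers: crop_image cuts the zoomed
# bounding box out of the full image and rescales it to the view, while
# insert_image pastes a view-sized mask back into full-image coordinates.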
def crop_image(
image,
box: BoundingBox | None = None,
image_shape: tuple[int, int] | None = None
):
if image_shape is None:
image_shape = image.shape[:2][::-1]
if box is None:
return cv2.resize(image, image_shape)
if len(image.shape) == 2:
return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend], image_shape)
return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend, :], image_shape)
def insert_image(image, box: BoundingBox | None = None):
new_image = np.zeros_like(image)
if box is None:
new_image = image
else:
if len(image.shape) == 2:
new_image[box.ystart:box.yend, box.xstart:box.xend] = cv2.resize(
image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
else:
new_image[box.ystart:box.yend, box.xstart:box.xend, :] = cv2.resize(
image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
return new_image
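# These fields mirror keyword arguments of SamAutomaticMaskGenerator; the
# dataclass is expanded with dataclasses.asdict() in Annotator.predict_all.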
@dataclasses.dataclass()
class AutomaticMaskGeneratorSettings:
points_per_side: int = 32
pred_iou_thresh: float = 0.88
stability_score_thresh: float = 0.95
stability_score_offset: float = 1.0
box_nms_thresh: float = 0.7
crop_n_layers: int = 0
crop_nms_thresh: float = 0.7
class LabelValueParam(QWidget):
def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None):
super().__init__(parent)
self.layout = QVBoxLayout(self)
self.layout.setSpacing(0)
self.layout.setContentsMargins(0, 0, 0, 0)
self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter)
self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter)
self.layout.addWidget(self.label)
self.layout.addWidget(self.value)
self.converter = value_type_converter
def get_value(self):
return self.converter(self.value.text())
class CustomForm(QWidget):
def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None:
super().__init__(parent)
self.layout = QVBoxLayout(self)
self.layout.setSpacing(0)
self.layout.setContentsMargins(0, 0, 0, 0)
self.widgets = []
for field in dataclasses.fields(automatic_mask_generator_settings):
widget = LabelValueParam(field.name, str(field.default), field.type)
self.widgets.append(widget)
self.layout.addWidget(widget)
def get_values(self):
return AutomaticMaskGeneratorSettings(**{widget.label.text(): widget.get_value() for widget in self.widgets})
class BoundingBoxAnnotation:
def __init__(self) -> None:
self.bounding_boxes: list[BoundingBox] = []
self.bounding_box_id: int = -1
def append(self, bounding_box: BoundingBox):
self.bounding_boxes.append(bounding_box)
def find_closest_bounding_box(self, point: np.ndarray):
closest_bounding_box = None
closest_bounding_box_id = -1
min_distance = float('inf')
for idx, bounding_box in enumerate(self.bounding_boxes):
distance = bounding_box.distance_to(point)
if distance < min_distance and bounding_box.contains(point):
min_distance = distance
closest_bounding_box = bounding_box
closest_bounding_box_id = idx
self.bounding_box_id = closest_bounding_box_id
return closest_bounding_box, closest_bounding_box_id
def get_bounding_box(self, bounding_box_id: int):
return self.bounding_boxes[bounding_box_id]
def get_current_bounding_box(self):
return self.bounding_boxes[self.bounding_box_id]
def set_current_bounding_box(self, bounding_box: BoundingBox):
self.bounding_boxes[self.bounding_box_id] = bounding_box
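    # Removal is uid-based so boxes stay in sync with their masks; the current
    # selection index shifts down when a box at or before it is removed.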
def remove(self, mask_uid: str):
bounding_box_id = next((idx for idx, bounding_box in enumerate(self.bounding_boxes) if bounding_box.mask_uid == mask_uid), None)
if bounding_box_id is None:
return
bounding_box = self.bounding_boxes.pop(bounding_box_id)
if self.bounding_box_id >= bounding_box_id:
self.bounding_box_id -= 1
return bounding_box
def remove_by_id(self, bounding_box_id: int):
mask_uid = self.bounding_boxes[bounding_box_id].mask_uid
self.remove(mask_uid)
return mask_uid
def __len__(self):
return len(self.bounding_boxes)
class MasksAnnotation:
DEFAULT_LABEL = "default"
def __init__(self) -> None:
self.masks = []
self.label_map = {}
self.masks_uids: list[str] = []
self.mask_id: int = -1
def add_mask(self, mask, label: str | None = None):
self.masks.append(mask)
self.masks_uids.append(str(uuid.uuid4()))
self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label
return self.masks_uids[-1]
def add_label(self, mask_id: int, label: str):
self.label_map[mask_id] = label
def get_mask(self, mask_id: int):
return self.masks[mask_id]
def get_label(self, mask_id: int):
return self.label_map[mask_id]
def get_current_mask(self):
return self.masks[self.mask_id]
    def set_current_mask(self, mask, label: str | None = None):
        self.masks[self.mask_id] = mask
        # label_map is keyed 1-based; keep the existing label when none is given
        # (e.g. when a mask is merely moved to the background).
        if label is not None:
            self.label_map[self.mask_id + 1] = label
def __getitem__(self, mask_id: int):
return self.get_mask(mask_id)
def __setitem__(self, mask_id: int, value):
self.masks[mask_id] = value
def __len__(self):
return len(self.masks)
    def __iter__(self):
        # Iterate over (mask, label) pairs; zip() supplies the iterator
        # protocol, so no separate __next__ is needed.
        return iter(zip(self.masks, self.label_map.values()))
def append(self, mask, label: str | None = None):
return self.add_mask(mask, label)
def pop_by_uuid(self, mask_uid: str):
mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None)
if mask_id is None:
return
return self.pop(mask_id)
    def pop(self, mask_id: int = -1):
        if mask_id < 0:
            # Normalize negative indices up front so the 1-based label_map key
            # below matches the list index being removed.
            mask_id += len(self.masks)
        _ = self.masks.pop(mask_id)
        mask_uid = self.masks_uids.pop(mask_id)
        self.label_map.pop(mask_id + 1)
        # Re-key the remaining labels so they stay contiguous and 1-based.
        self.label_map = {index + 1: value for index, value in enumerate(self.label_map.values())}
        return mask_uid
@classmethod
def from_masks(cls, masks, labels: list[str] | None = None):
annotation = cls()
if labels is None:
labels = [None] * len(masks)
for mask, label in zip(masks, labels):
annotation.append(mask, label)
return annotation
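# Central annotation state: holds the SAM backbone and predictor, the image and
# its embedding, accumulated masks/bounding boxes, and visualization buffers.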
@dataclasses.dataclass()
class Annotator:
sam: Sam | EfficientViTSam | None = None
embedding: torch.Tensor | None = None
image: np.ndarray | None = None
masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation)
bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation)
predictor: SamPredictor | EfficientViTSamPredictor | None = None
visualization: np.ndarray | None = None
last_mask: np.ndarray | None = None
partial_mask: np.ndarray | None = None
merged_mask: np.ndarray | None = None
parent: QWidget | None = None
    cmap: Callable | None = None
original_image: np.ndarray | None = None
zoomed_bounding_box: BoundingBox | None = None
def __post_init__(self):
self.MAX_MASKS = 10
self.cmap = get_cmap(self.MAX_MASKS)
def set_image(self, image: np.ndarray):
self.image = image
return self
def make_embedding(self):
if self.sam is None:
return
self.predictor = get_predictor(self.sam)
self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box))
def predict_all(self, settings: AutomaticMaskGeneratorSettings):
generator = get_mask_generator(
sam=self.sam,
**dataclasses.asdict(settings)
)
masks = generator.generate(self.image)
masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks]
label = self.parent.annotation_layout.label_picker.currentItem().text()
self.masks = MasksAnnotation.from_masks(masks, [label, ] * len(masks))
self.cmap = get_cmap(len(self.masks))
def make_prediction(self, annotation: dict):
masks, scores, logits = self.predictor.predict(
point_coords=annotation["points"],
point_labels=annotation["labels"],
box=annotation["bounding_boxes"],
multimask_output=False
)
mask = masks[0]
self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255
def pick_partial_mask(self):
if self.partial_mask is None:
self.partial_mask = self.last_mask.copy()
else:
self.partial_mask = np.maximum(self.last_mask, self.partial_mask)
self.last_mask = None
def move_current_mask_to_background(self):
self.masks.set_current_mask(self.masks.get_current_mask() * 0.5)
def merge_masks(self):
new_mask = np.bitwise_or(self.last_mask, self.merged_mask)
self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text())
self.merged_mask = None
def visualize_last_mask(self, label: str | None = None):
last_mask = np.zeros_like(self.image)
last_mask[:, :, 1] = self.last_mask
if self.partial_mask is not None:
last_mask[:, :, 0] = self.partial_mask
if self.merged_mask is not None:
last_mask[:, :, 2] = self.merged_mask
if label is not None:
props = regionprops(self.last_mask)[0]
cv2.putText(
last_mask,
label,
(int(props.centroid[1]), int(props.centroid[0])),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
[255, 255, 255],
2
)
if self.is_show_bounding_boxes:
last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label)
cv2.rectangle(
last_mask,
(int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
(int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])),
(0, 255, 0),
2
)
cv2.putText(
last_mask,
label,
(int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
[255, 255, 255],
2
)
visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(),
0.8, last_mask, 0.5, 0)
self.parent.update(crop_image(visualization, self.zoomed_bounding_box))
def _visualize_mask(self) -> tuple:
mask_argmax = self.make_instance_mask()
visualization = np.zeros_like(self.image)
border = np.zeros(self.image.shape[:2], dtype=np.uint8)
for i in range(1, np.amax(mask_argmax) + 1):
color = self.cmap(i)
single_mask = np.zeros_like(mask_argmax)
single_mask[mask_argmax == i] = 1
visualization[mask_argmax == i, :] = np.array(color[:3]) * 255
border += single_mask - cv2.erode(
single_mask, np.ones((3, 3), np.uint8), iterations=1)
label = self.masks.get_label(i)
single_mask_center = np.mean(np.where(single_mask == 1), axis=1)
if np.isnan(single_mask_center).any():
continue
if self.parent.settings.is_show_text():
cv2.putText(
visualization,
label,
(int(single_mask_center[1]), int(single_mask_center[0])),
cv2.FONT_HERSHEY_PLAIN,
0.5,
[255, 255, 255],
1
)
if self.is_show_bounding_boxes:
bounding_boxes = self.get_bounding_boxes()
for idx, bounding_box in enumerate(bounding_boxes):
cv2.rectangle(
visualization,
(int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
(int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])),
(0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0),
2
)
cv2.putText(
visualization,
bounding_box.label,
(int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
[255, 255, 255],
2
)
border = (border == 0).astype(np.uint8)
return visualization, border
def has_annotations(self):
return len(self.masks) > 0
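    # Instance mask via argmax: stack a constant background channel under all
    # masks, so pixel value i means "covered by mask i" (1-based) and 0 means
    # background; where masks overlap the earlier one wins, since np.argmax
    # returns the first maximum.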
def make_instance_mask(self):
background = np.zeros_like(self.masks[0]) + 1
mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8)
return mask_argmax
def get_bounding_boxes(self):
return get_bounding_boxes(self.masks.masks, self.masks.label_map.values())
def merge_image_visualization(self):
image = self.image.copy()
if not len(self.masks):
return crop_image(image, self.zoomed_bounding_box)
visualization, border = self._visualize_mask()
self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis]
return crop_image(self.visualization, self.zoomed_bounding_box)
    def remove_last_mask(self):
        # pop() hands back the mask uid, which is what
        # BoundingBoxAnnotation.remove expects.
        mask_uid = self.masks.pop(-1)
        self.bounding_boxes.remove(mask_uid)
def make_labels(self):
return self.masks.label_map
def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL):
if self.partial_mask is not None:
last_mask = self.partial_mask
self.partial_mask = None
else:
last_mask = self.last_mask
mask_uid = self.masks.add_mask(last_mask, label=label)
corresponding_bounding_box = get_mask_bounding_box(last_mask, label)
corresponding_bounding_box.mask_uid = mask_uid
self.bounding_boxes.append(corresponding_bounding_box)
if len(self.masks) >= self.MAX_MASKS:
self.MAX_MASKS += 10
self.cmap = get_cmap(self.MAX_MASKS)
@property
def is_show_bounding_boxes(self):
return self.parent.settings.is_show_bounding_boxes()
def clear_last_masks(self):
self.last_mask = None
self.partial_mask = None
self.visualization = None
def clear(self):
self.last_mask = None
self.visualization = None
self.masks = MasksAnnotation()
self.bounding_boxes = BoundingBoxAnnotation()
self.partial_mask = None

View File

@@ -0,0 +1,143 @@
import dataclasses
import os
from typing import Literal
from PySide6.QtCore import Qt
import requests
try:
from tqdm import tqdm
import wget
except ImportError:
print("Tqdm and wget not found. Install with pip install tqdm wget")
tqdm = None
wget = None
@dataclasses.dataclass(frozen=True)
class Keymap:
key: Qt.Key | str
name: str
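# Adapter between wget's progress callback, which is invoked as
# bar(current_bytes, total_bytes, width), and a tqdm bar measured in MB.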
class ProgressBar:
def __init__(self):
self.progress_bar = None
def __call__(self, current_bytes, total_bytes, width):
current_mb = round(current_bytes / 1024 ** 2, 1)
total_mb = round(total_bytes / 1024 ** 2, 1)
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=total_mb, unit="MB")
delta_mb = current_mb - self.progress_bar.n
self.progress_bar.update(delta_mb)
@dataclasses.dataclass
class KeyBindings:
ADD_POINT: Keymap = Keymap(Qt.Key.Key_W, "W")
ADD_BOX: Keymap = Keymap(Qt.Key.Key_Q, "Q")
ANNOTATE_ALL: Keymap = Keymap(Qt.Key.Key_Return, "Enter")
MANUAL_POLYGON: Keymap = Keymap(Qt.Key.Key_R, "R")
CANCEL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_C, "C")
SAVE_ANNOTATION: Keymap = Keymap(Qt.Key.Key_S, "S")
PICK_MASK: Keymap = Keymap(Qt.Key.Key_X, "X")
PICK_BOUNDING_BOX: Keymap = Keymap(Qt.Key.Key_B, "B")
MERGE_MASK: Keymap = Keymap(Qt.Key.Key_Z, "Z")
DELETE_MASK: Keymap = Keymap(Qt.Key.Key_V, "V")
PARTIAL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_D, "D")
SAVE_BOUNDING_BOXES: Keymap = Keymap("Ctrl+B", "Ctrl+B")
NEXT_FILE: Keymap = Keymap(Qt.Key.Key_F, "F")
PREVIOUS_FILE: Keymap = Keymap(Qt.Key.Key_G, "G")
SAVE_MASK: Keymap = Keymap("Ctrl+S", "Ctrl+S")
PRECOMPUTE: Keymap = Keymap(Qt.Key.Key_P, "P")
ZOOM_RECTANGLE: Keymap = Keymap(Qt.Key.Key_E, "E")
@dataclasses.dataclass
class Config:
default_weights: Literal[
"sam_vit_b_01ec64.pth",
"sam_vit_h_4b8939.pth",
"sam_vit_l_0b3195.pth",
"xl0.pt",
"xl1.pt",
"sam_hq_vit_b.pth",
"sam_hq_vit_l.pth",
"sam_hq_vit_h.pth",
"sam_hq_vit_tiny.pth",
"sam2.1_hiera_t.pth",
"sam2.1_hiera_l.pth",
"sam2.1_hiera_b+.pth",
"sam2.1_hiera_s.pth",
] = "sam_vit_h_4b8939.pth"
download_weights_if_not_available: bool = True
label_file: str = "labels.json"
window_size: tuple[int, int] | int = (1920, 1080)
key_mapping: KeyBindings = dataclasses.field(default_factory=KeyBindings)
weights_paths: dict[str, str] = dataclasses.field(default_factory=lambda: {
"l2": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"xl0": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt",
"xl1": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt",
"hq_vit_b": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
"hq_vit_l": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
"hq_vit_h": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
"hq_vit_tiny": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
"sam2.1_hiera_t": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
"sam2.1_hiera_s": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
"sam2.1_hiera_b+": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
"sam2.1_hiera_l": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
})
def __post_init__(self):
if isinstance(self.window_size, int):
self.window_size = (self.window_size, self.window_size)
if self.download_weights_if_not_available:
self.download_weights()
    def get_sam_model_name(self):
        # Map substrings of the weights file name to model names; checked in order.
        substring_to_name = (
            ("l2", "l2"),
            ("sam_vit_b", "vit_b"),
            ("sam_vit_h", "vit_h"),
            ("sam_vit_l", "vit_l"),
            ("xl0", "xl0"),
            ("xl1", "xl1"),
            ("hq_vit_b", "hq_vit_b"),
            ("hq_vit_l", "hq_vit_l"),
            ("hq_vit_h", "hq_vit_h"),
            ("hq_vit_tiny", "hq_vit_tiny"),
            ("sam2.1_hiera_t", "sam2.1_hiera_t"),
            ("sam2.1_hiera_l", "sam2.1_hiera_l"),
            ("sam2.1_hiera_b+", "sam2.1_hiera_b+"),
            ("sam2.1_hiera_s", "sam2.1_hiera_s"),
        )
        for substring, name in substring_to_name:
            if substring in self.default_weights:
                return name
        raise ValueError(f"Unknown model name: {self.default_weights}")
def download_weights(self):
if not os.path.exists(self.default_weights):
try:
print(f"Downloading weights for model {self.get_sam_model_name()}")
wget.download(self.weights_paths[self.get_sam_model_name()], self.default_weights, bar=ProgressBar())
print(f"Downloaded weights to {self.default_weights}")
except Exception as e:
print(f"Error downloading weights: {e}. Trying with requests.")
model_name = self.get_sam_model_name()
print(f"Downloading weights for model {model_name}")
file = requests.get(self.weights_paths[model_name])
with open(self.default_weights, "wb") as f:
f.write(file.content)
print(f"Downloaded weights to {self.default_weights}")

View File

@@ -0,0 +1,268 @@
import copy
from enum import Enum
import cv2
import numpy as np
import PySide6
from PySide6 import QtCore, QtWidgets
from PySide6.QtGui import QPainter, QPen
from segment_anything_ui.config import Config
from segment_anything_ui.utils.shapes import BoundingBox, Polygon
class PaintType(Enum):
POINT = 0
BOX = 1
MASK = 2
POLYGON = 3
MASK_PICKER = 4
ZOOM_PICKER = 5
BOX_PICKER = 6
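# Cycles through the ids of overlapping masks so that repeated clicks on the
# same pixel select each candidate mask in turn.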
class MaskIdPicker:
def __init__(self, length) -> None:
self.counter = 0
self.length = length
def increment(self):
self.counter = (self.counter + 1) % self.length
def pick(self, ids):
print("Length of ids: ", len(ids), " counter: ", self.counter, " ids: ", ids)
if len(ids) <= self.counter:
self.counter = 0
return_id = ids[self.counter]
self.increment()
return return_id
class DrawLabel(QtWidgets.QLabel):
def __init__(self, parent=None):
super().__init__(parent)
self.positive_points = []
self.negative_points = []
self.bounding_box = None
self.partial_box = BoundingBox(0, 0)
self._paint_type = PaintType.POINT
self.polygon = Polygon()
self.mask_enum: MaskIdPicker = MaskIdPicker(3)
        # Reuse the parent window's Config when it has one; constructing a new
        # Config here would re-run __post_init__ and could re-download weights.
        self.config = getattr(parent, "config", None) or Config()
self.setFocusPolicy(QtCore.Qt.StrongFocus)
self._zoom_center = (0, 0)
self._zoom_factor = 1.0
self._zoom_bounding_box: BoundingBox | None = None
def paintEvent(self, paint_event):
painter = QPainter(self)
painter.drawPixmap(self.rect(), self.pixmap())
pen_positive = self._get_pen(QtCore.Qt.green, 5)
pen_negative = self._get_pen(QtCore.Qt.red, 5)
pen_partial = self._get_pen(QtCore.Qt.yellow, 1)
pen_box = self._get_pen(QtCore.Qt.green, 1)
painter.setRenderHint(QPainter.Antialiasing, False)
painter.setPen(pen_box)
if self.bounding_box is not None and self.bounding_box.xend != -1 and self.bounding_box.yend != -1:
painter.drawRect(
self.bounding_box.xstart,
self.bounding_box.ystart,
self.bounding_box.xend - self.bounding_box.xstart,
self.bounding_box.yend - self.bounding_box.ystart
)
painter.setPen(pen_partial)
painter.drawRect(self.partial_box.xstart, self.partial_box.ystart,
self.partial_box.xend - self.partial_box.xstart,
self.partial_box.yend - self.partial_box.ystart)
painter.setPen(pen_positive)
for pos in self.positive_points:
painter.drawPoint(pos)
painter.setPen(pen_negative)
painter.setRenderHint(QPainter.Antialiasing, False)
for pos in self.negative_points:
painter.drawPoint(pos)
if self.polygon.is_plotable():
painter.setPen(pen_box)
painter.setRenderHint(QPainter.Antialiasing, True)
painter.drawPolygon(self.polygon.to_qpolygon())
# self.update()
def _get_pen(self, color=QtCore.Qt.red, width=5):
pen = QPen()
pen.setWidth(width)
pen.setColor(color)
return pen
@property
def paint_type(self):
return self._paint_type
def change_paint_type(self, paint_type: PaintType):
print(f"Changing paint type to {paint_type}")
self._paint_type = paint_type
def mouseMoveEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]:
self.partial_box = copy.deepcopy(self.bounding_box)
self.partial_box.xend = ev.pos().x()
self.partial_box.yend = ev.pos().y()
self.update()
if self._paint_type == PaintType.POINT:
point = ev.pos()
if ev.buttons() == QtCore.Qt.LeftButton:
self._move_update(None, point)
elif ev.buttons() == QtCore.Qt.RightButton:
self._move_update(point, None)
else:
pass
self.update()
def _move_update(self, temporary_point_negative, temporary_point_positive):
annotations = self.get_annotations(temporary_point_positive, temporary_point_negative)
self.parent().annotator.make_prediction(annotations)
self.parent().annotator.visualize_last_mask()
def mouseReleaseEvent(self, cursor_event):
if self._paint_type == PaintType.POINT:
if cursor_event.button() == QtCore.Qt.LeftButton:
self.positive_points.append(cursor_event.pos())
print(self.size())
elif cursor_event.button() == QtCore.Qt.RightButton:
self.negative_points.append(cursor_event.pos())
# self.chosen_points.append(self.mapFromGlobal(QtGui.QCursor.pos()))
elif self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]:
if cursor_event.button() == QtCore.Qt.LeftButton:
self.bounding_box.xend = cursor_event.pos().x()
self.bounding_box.yend = cursor_event.pos().y()
self.partial_box = BoundingBox(-1, -1, -1, -1)
if not self._paint_type == PaintType.MASK_PICKER and \
not self._paint_type == PaintType.ZOOM_PICKER and \
not self._paint_type == PaintType.POLYGON and \
not self._paint_type == PaintType.BOX_PICKER:
self.parent().annotator.make_prediction(self.get_annotations())
self.parent().annotator.visualize_last_mask()
if self._paint_type == PaintType.ZOOM_PICKER:
self.parent().annotator.zoomed_bounding_box = self.bounding_box.scale(*self._get_scale()).to_int()
self.bounding_box = None
self.parent().annotator.make_embedding()
self.parent().update(self.parent().annotator.merge_image_visualization())
self._paint_type = PaintType.POINT
self.update()
def mousePressEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER] and ev.button() == QtCore.Qt.LeftButton:
self.bounding_box = BoundingBox(xstart=ev.pos().x(), ystart=ev.pos().y())
if self._paint_type == PaintType.POLYGON and ev.button() == QtCore.Qt.LeftButton:
self.polygon.points.append([ev.pos().x(), ev.pos().y()])
if self._paint_type == PaintType.MASK_PICKER and ev.button() == QtCore.Qt.LeftButton:
size = self.size()
point = [
int(ev.pos().x() / size.width() * self.config.window_size[0]),
int(ev.pos().y() / size.height() * self.config.window_size[1])]
masks = np.array(self.parent().annotator.masks.masks)
mask_ids = np.where(masks[:, point[1], point[0]])[0]
print("Picking mask at point: {}".format(point))
            if len(mask_ids) == 0:
print("No mask found")
mask_id = -1
local_mask = np.zeros((masks.shape[1], masks.shape[2]))
label = None
else:
mask_id = self.mask_enum.pick(mask_ids)
local_mask = self.parent().annotator.masks.get_mask(mask_id)
label = self.parent().annotator.masks.get_label(mask_id + 1)
self.parent().annotator.masks.mask_id = mask_id
self.parent().annotator.last_mask = local_mask
self.parent().annotator.visualize_last_mask(label)
if self._paint_type == PaintType.BOX_PICKER and ev.button() == QtCore.Qt.LeftButton:
size = self.size()
point = [
float(ev.pos().x() / size.width()),
float(ev.pos().y() / size.height())]
bounding_box, bounding_box_id = self.parent().annotator.bounding_boxes.find_closest_bounding_box(point)
if bounding_box is None:
print("No bounding box found")
else:
self.parent().annotator.bounding_boxes.bounding_box_id = bounding_box_id
print(f"Bounding box: {bounding_box}")
print(f"Bounding box id: {bounding_box_id}")
self.parent().update(self.parent().annotator.merge_image_visualization())
if self._paint_type == PaintType.POINT:
point = ev.pos()
if ev.button() == QtCore.Qt.LeftButton:
self._move_update(None, point)
if ev.button() == QtCore.Qt.RightButton:
self._move_update(point, None)
self.update()
def zoom_to_rectangle(self, xstart, ystart, xend, yend):
picked_image = self.parent().annotator.image[ystart:yend, xstart:xend, :]
self.parent().annotator.image = cv2.resize(picked_image, (self.config.window_size[0], self.config.window_size[1]))
self.update()
def keyPressEvent(self, ev: PySide6.QtGui.QKeyEvent) -> None:
print(ev.key())
if self._paint_type == PaintType.MASK_PICKER and ev.key() == QtCore.Qt.Key.Key_D and len(self.parent().annotator.masks):
print("Deleting mask")
self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
self.parent().annotator.masks.mask_id = -1
self.parent().annotator.last_mask = None
self.parent().update(self.parent().annotator.merge_image_visualization())
def _get_scale(self):
return self.config.window_size[0] / self.size().width(), self.config.window_size[1] / self.size().height()
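    # Scales widget-space clicks into model (window_size) coordinates and packs
    # them in the format SamPredictor.predict expects: point label 1 marks a
    # positive (foreground) point, 0 a negative (background) one.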
def get_annotations(
self,
temporary_point_positive: PySide6.QtCore.QPoint | None = None,
temporary_point_negative: PySide6.QtCore.QPoint | None = None
):
sx, sy = self._get_scale()
positive_points = [(
p.x() * sx,
p.y() * sy) for p in self.positive_points]
negative_points = [(
p.x() * sx,
p.y() * sy) for p in self.negative_points]
if temporary_point_positive:
positive_points += [(temporary_point_positive.x() * sx, temporary_point_positive.y() * sy)]
if temporary_point_negative:
negative_points += [(temporary_point_negative.x() * sx, temporary_point_negative.y() * sy)]
positive_points = np.array(positive_points).reshape(-1, 2)
negative_points = np.array(negative_points).reshape(-1, 2)
labels = np.array([1, ] * len(positive_points) + [0, ] * len(negative_points))
print(f"Positive points: {positive_points}")
print(f"Negative points: {negative_points}")
print(f"Labels: {labels}")
return {
"points": np.concatenate([positive_points, negative_points], axis=0),
"labels": labels,
"bounding_boxes": self.bounding_box.scale(sx, sy).to_numpy() if self.bounding_box else None
}
def clear(self):
self.positive_points = []
self.negative_points = []
self.bounding_box = None
self.partial_box = BoundingBox(0, 0, 0, 0)
self.polygon = Polygon()
self.update()

View File

@@ -0,0 +1,14 @@
from PySide6.QtGui import QImage, QPixmap
class ImagePixmap(QPixmap):
    def __init__(self):
        super().__init__()
    def set_image(self, image):
        # Accept uint8 images directly and float images scaled to [0, 1].
        if image.dtype != "uint8":
            image = (image * 255).astype("uint8")
        # Pass the row stride explicitly: without it, QImage assumes rows are
        # 32-bit aligned and shears images whose width is not a multiple of 4.
        qimage = QImage(image.data, image.shape[1], image.shape[0],
                        image.strides[0], QImage.Format_RGB888)
        self.convertFromImage(qimage)

View File

@@ -0,0 +1,81 @@
import logging
import sys
import cv2
import numpy as np
import torch
from PySide6.QtWidgets import (QApplication, QGridLayout, QLabel,
QMessageBox, QWidget)
from PySide6.QtCore import Qt
from segment_anything_ui.annotator import Annotator
from segment_anything_ui.annotation_layout import AnnotationLayout
from segment_anything_ui.config import Config
from segment_anything_ui.draw_label import DrawLabel
from segment_anything_ui.image_pixmap import ImagePixmap
from segment_anything_ui.model_builder import build_model
from segment_anything_ui.settings_layout import SettingsLayout
class SegmentAnythingUI(QWidget):
def __init__(self, config) -> None:
super().__init__()
self.config: Config = config
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.setWindowTitle("Segment Anything UI")
self.setWindowState(Qt.WindowState.WindowMaximized)
# self.setGeometry(100, 100, 800, 600)
self.layout = QGridLayout(self)
self.image_label = DrawLabel(self)
self.settings = SettingsLayout(self, config=self.config)
self.info_label = QLabel("Information about running process.")
self.sam = self.init_sam()
self.annotator = Annotator(sam=self.sam, parent=self)
self.annotation_layout = AnnotationLayout(self, config=self.config)
self.layout.addWidget(self.annotation_layout, 0, 0, 1, 1, Qt.AlignCenter)
self.layout.addWidget(self.image_label, 0, 1, 1, 1, Qt.AlignCenter)
self.layout.addWidget(self.settings, 0, 3, 1, 1, Qt.AlignCenter)
self.layout.addWidget(self.info_label, 1, 1, Qt.AlignBottom)
self.set_image(np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8))
self.show()
def set_image(self, image: np.ndarray, clear_annotations: bool = True):
self.annotator.set_image(image).make_embedding()
if clear_annotations:
self.annotator.clear()
self.update(image)
def update(self, image: np.ndarray):
image = cv2.resize(image, self.config.window_size)
pixmap = ImagePixmap()
pixmap.set_image(image)
print("Updating image")
self.image_label.setPixmap(pixmap)
def init_sam(self):
try:
checkpoint_path = str(self.settings.checkpoint_path.text())
sam = build_model(self.config.get_sam_model_name(), checkpoint_path, self.device)
except Exception as e:
logging.getLogger().exception(f"Error loading model: {e}")
QMessageBox.critical(self, "Error", "Could not load model")
return None
return sam
def get_mask(self):
return self.annotator.make_instance_mask()
def get_labels(self):
return self.annotator.make_labels()
def get_bounding_boxes(self):
return self.annotator.get_bounding_boxes()
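# Entry point; presumably launched as `python -m segment_anything_ui.main_window`
# (module path inferred from the package imports above).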
if __name__ == '__main__':
app = QApplication(sys.argv)
ex = SegmentAnythingUI(Config())
sys.exit(app.exec())

View File

@@ -0,0 +1,129 @@
import os
from PySide6.QtWidgets import QMessageBox
try:
from efficientvit.sam_model_zoo import create_efficientvit_sam_model, EfficientViTSam
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, EfficientViTSamAutomaticMaskGenerator
IS_EFFICIENT_VIT_AVAILABLE = True
except (ModuleNotFoundError, ImportError) as e:
import logging
logging.warning("Efficient is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
IS_EFFICIENT_VIT_AVAILABLE = False
try:
from segment_anything_hq import sam_model_registry as sam_hq_model_registry
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import automatic_mask_generator as automatic_mask_generator_hq
from segment_anything_hq.build_sam import Sam as SamHQ
IS_SAM_HQ_AVAILABLE = True
_SAM_HQ_MODEL_REGISTRY = {
"hq_vit_b": "vit_b",
"hq_vit_l": "vit_l",
"hq_vit_h": "vit_h",
"hq_vit_tiny": "vit_tiny",
}
except (ModuleNotFoundError, ImportError) as e:
import logging
logging.warning("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
IS_SAM_HQ_AVAILABLE = False
_SAM_HQ_MODEL_REGISTRY = {}
try:
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
IS_SAM2_AVAILABLE = True
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize
# Reset Hydra's global configuration
if GlobalHydra.instance().is_initialized():
GlobalHydra.instance().clear()
_SAM2_MODEL_REGISTRY = {
"sam2.1_hiera_t": "sam2.1_hiera_t.yaml",
"sam2.1_hiera_l": "sam2.1_hiera_l.yaml",
"sam2.1_hiera_b": "sam2.1_hiera_b+.yaml",
"sam2.1_hiera_s": "sam2.1_hiera_s.yaml",
}
except (ModuleNotFoundError, ImportError) as e:
import logging
logging.warning("SAM2 is not available, please install the package from https://github.com/SysCV/sam2 .")
IS_SAM2_AVAILABLE = False
_SAM2_MODEL_REGISTRY = {}
from segment_anything import sam_model_registry
from segment_anything import SamPredictor, automatic_mask_generator
from segment_anything.build_sam import Sam
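# build_model dispatches on the model name produced by Config.get_sam_model_name()
# and returns the chosen backbone in eval mode on the requested device; each
# model family requires its optional dependency to be installed.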
def build_model(model_name: str, checkpoint_path: str, device: str):
match model_name:
case "xl0" | "xl1":
if not IS_EFFICIENT_VIT_AVAILABLE:
raise ValueError("EfficientViTSam is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
efficientvit_sam = create_efficientvit_sam_model(
name=model_name, weight_url=checkpoint_path,
)
return efficientvit_sam.to(device).eval()
case "vit_b" | "vit_l" | "vit_h":
sam = sam_model_registry[model_name](
checkpoint=checkpoint_path)
sam.eval()
return sam.to(device)
case "hq_vit_b" | "hq_vit_l" | "hq_vit_h":
if not IS_SAM_HQ_AVAILABLE:
QMessageBox.critical(None, "Segment Anything HQ is not available", "Please install the package from http://github.com/SysCV/sam-hq .")
raise ValueError("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
sam = sam_hq_model_registry[_SAM_HQ_MODEL_REGISTRY[model_name]](
checkpoint=checkpoint_path)
sam.eval()
return sam.to(device)
case "sam2.1_hiera_t" | "sam2.1_hiera_l" | "sam2.1_hiera_b" | "sam2.1_hiera_s":
if not IS_SAM2_AVAILABLE:
QMessageBox.critical(None, "SAM2 is not available", "Please install the package from https://github.com/facebookresearch/sam2 .")
raise ValueError("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
with initialize(version_base=None, config_path="sam2_configs"):
sam = build_sam2(_SAM2_MODEL_REGISTRY[model_name], checkpoint_path, device=device)
sam.eval()
return sam
case _:
raise ValueError(f"Model {model_name} not supported")
def get_predictor(sam):
if isinstance(sam, Sam):
return SamPredictor(sam)
elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
return EfficientViTSamPredictor(sam)
elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
return SamPredictorHQ(sam)
elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
return SAM2ImagePredictor(sam)
    else:
        raise ValueError(f"Unsupported model type for predictor: {type(sam)}")
def get_mask_generator(sam, **kwargs):
if isinstance(sam, Sam):
return automatic_mask_generator.SamAutomaticMaskGenerator(model=sam, **kwargs)
elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
return automatic_mask_generator_hq.SamAutomaticMaskGeneratorHQ(model=sam, **kwargs)
elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
return EfficientViTSamAutomaticMaskGenerator(model=sam, **kwargs)
    elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        # Note: the AutomaticMaskGeneratorSettings kwargs are not forwarded
        # here, so SAM2AutomaticMaskGenerator runs with its default parameters.
        return SAM2AutomaticMaskGenerator(model=sam)
    else:
        raise ValueError(f"Unsupported model type for mask generator: {type(sam)}")

View File

View File

@@ -0,0 +1,45 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from segment_anything_ui.modeling.efficientvit.models.efficientvit import (
EfficientViTSam,
efficientvit_sam_l0,
efficientvit_sam_l1,
efficientvit_sam_l2,
)
from segment_anything_ui.modeling.efficientvit.models.nn.norm import set_norm_eps
from segment_anything_ui.modeling.efficientvit.models.utils import load_state_dict_from_file
__all__ = ["create_sam_model"]
REGISTERED_SAM_MODEL: dict[str, str] = {
"l0": "assets/checkpoints/sam/l0.pt",
"l1": "assets/checkpoints/sam/l1.pt",
"l2": "assets/checkpoints/sam/l2.pt",
}
def create_sam_model(name: str, pretrained=True, weight_url: str | None = None, **kwargs) -> EfficientViTSam:
model_dict = {
"l0": efficientvit_sam_l0,
"l1": efficientvit_sam_l1,
"l2": efficientvit_sam_l2,
}
model_id = name.split("-")[0]
if model_id not in model_dict:
raise ValueError(f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}")
else:
model = model_dict[model_id](**kwargs)
set_norm_eps(model, 1e-6)
if pretrained:
weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
if weight_url is None:
raise ValueError(f"Do not find the pretrained weight of {name}.")
else:
weight = load_state_dict_from_file(weight_url)
model.load_state_dict(weight)
return model

View File

@@ -0,0 +1,30 @@
from safetensors import safe_open
from segment_anything.modeling import Sam
import torch.nn as nn
class ModifiedImageEncoder(nn.Module):
    def __init__(self, image_encoder, saved_path: str | None = None):
        super().__init__()
        self.image_encoder = image_encoder
        if saved_path is not None:
            # safe_open requires the tensor framework in addition to the path.
            self.embeddings = safe_open(saved_path, framework="pt")
        else:
            self.embeddings = None
    def forward(self, x):
        # Use the precomputed embeddings when available, otherwise run the encoder.
        return self.image_encoder(x) if self.embeddings is None else self.embeddings
class StorableSam:
def __init__(self, sam):
self.sam = sam
self.image_encoder = sam.image_encoder
def transform(self, saved_path):
self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path)
def precompute(self, image):
return self.image_encoder(image)

View File

@@ -0,0 +1,116 @@
# @package _global_
# Model
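# Presumably sam2.1_hiera_b+.yaml: embed_dim 112 and the [896, 448, 224, 112]
# backbone_channel_list match the Hiera base-plus trunk.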
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [896, 448, 224, 112]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False

View File

@@ -0,0 +1,120 @@
# @package _global_
# Model
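# Presumably sam2.1_hiera_l.yaml: embed_dim 144 with stages [2, 6, 36, 4]
# matches the Hiera large trunk.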
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 144
num_heads: 2
stages: [2, 6, 36, 4]
global_att_blocks: [23, 33, 43]
window_pos_embed_bkg_spatial_size: [7, 7]
window_spec: [8, 4, 16, 8]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [1152, 576, 288, 144]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False

View File

@@ -0,0 +1,119 @@
# @package _global_
# Model
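# Presumably sam2.1_hiera_s.yaml: embed_dim 96 with stages [1, 2, 11, 2]
# matches the Hiera small trunk.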
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False

View File

@@ -0,0 +1,121 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 7, 2]
      global_att_blocks: [5, 7, 9]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  # SAM decoder
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  no_obj_embed_spatial: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: true
  proj_tpos_enc_in_obj_ptrs: true
  use_signed_tpos_enc_to_obj_ptrs: true
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  # HieraT does not currently support compilation, should always be set to False
  compile_image_encoder: False
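Every `_target_` key above is a standard Hydra instantiation hook, so sub-trees of the config can be built in isolation. A minimal sketch, assuming hydra-core and omegaconf are installed, that instantiates just the Hiera trunk of this tiny variant:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create("""
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
""")
trunk = instantiate(cfg)  # returns a torch.nn.Module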

View File

@@ -0,0 +1,13 @@
import os
from typing import Any

import torch


class Saver:

    def __init__(self, path: str) -> None:
        self.path = path

    def __call__(self, basename, mask_annotation) -> Any:
        save_path = os.path.join(self.path, basename)
        # TODO: finish this
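The saver is left as a stub. One hedged sketch of how the TODO might be finished, assuming MasksAnnotation exposes its masks as an iterable of numpy arrays (a guess, not confirmed by this commit):

import numpy as np

class SaverSketch(Saver):
    def __call__(self, basename, mask_annotation) -> Any:
        save_path = os.path.join(self.path, basename)
        # assumed interface: mask_annotation.masks yields HxW numpy arrays
        stacked = np.stack(list(mask_annotation.masks))
        torch.save(torch.from_numpy(stacked), save_path + ".pt")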

View File

@@ -0,0 +1,228 @@
import json
import os
import pathlib
import random

import cv2
import numpy as np
from PySide6.QtWidgets import QPushButton, QWidget, QFileDialog, QVBoxLayout, QLineEdit, QLabel, QCheckBox, QMessageBox

from segment_anything_ui.annotator import MasksAnnotation
from segment_anything_ui.config import Config
from segment_anything_ui.utils.bounding_boxes import BoundingBox


class FilesHolder:

    def __init__(self):
        self.files = []
        self.position = 0

    def add_files(self, files):
        self.files.extend(files)

    def get_next(self):
        self.position += 1
        if self.position >= len(self.files):
            self.position = 0
        return self.files[self.position]

    def get_previous(self):
        self.position -= 1
        if self.position < 0:
            self.position = len(self.files) - 1
        return self.files[self.position]


class SettingsLayout(QWidget):
    MASK_EXTENSION = "_mask.png"
    LABELS_EXTENSION = "_labels.json"
    BOUNDING_BOXES_EXTENSION = "_bounding_boxes.json"

    def __init__(self, parent, config: Config) -> None:
        super().__init__(parent)
        self.config = config
        self.actual_file: str = ""
        self.actual_shape = self.config.window_size
        self.layout = QVBoxLayout(self)
        self.open_files = QPushButton("Open Files")
        self.open_files.clicked.connect(self.on_open_files)
        self.next_file = QPushButton(f"Next File [ {config.key_mapping.NEXT_FILE.name} ]")
        self.previous_file = QPushButton(f"Previous file [ {config.key_mapping.PREVIOUS_FILE.name} ]")
        self.previous_file.setShortcut(config.key_mapping.PREVIOUS_FILE.key)
        self.save_mask = QPushButton(f"Save Mask [ {config.key_mapping.SAVE_MASK.name} ]")
        self.save_mask.clicked.connect(self.on_save_mask)
        self.save_mask.setShortcut(config.key_mapping.SAVE_MASK.key)
        self.save_bounding_boxes = QPushButton(f"Save Bounding Boxes [ {config.key_mapping.SAVE_BOUNDING_BOXES.name} ]")
        self.save_bounding_boxes.clicked.connect(self.on_save_bounding_boxes)
        self.save_bounding_boxes.setShortcut(config.key_mapping.SAVE_BOUNDING_BOXES.key)
        self.next_file.clicked.connect(self.on_next_file)
        self.next_file.setShortcut(config.key_mapping.NEXT_FILE.key)
        self.previous_file.clicked.connect(self.on_previous_file)
        self.checkpoint_path_label = QLabel(self, text="Checkpoint Path")
        self.checkpoint_path = QLineEdit(self, text=self.parent().config.default_weights)
        self.precompute_button = QPushButton("Precompute all embeddings")
        self.precompute_button.clicked.connect(self.on_precompute)
        self.delete_existing_annotation = QPushButton("Delete existing annotation")
        self.delete_existing_annotation.clicked.connect(self.on_delete_existing_annotation)
        self.show_image = QPushButton("Show Image")
        self.show_visualization = QPushButton("Show Visualization")
        self.show_bounding_boxes = QCheckBox("Show Bounding Boxes")
        self.show_bounding_boxes.clicked.connect(self.on_show_bounding_boxes)
        self.show_image.clicked.connect(self.on_show_image)
        self.show_visualization.clicked.connect(self.on_show_visualization)
        self.show_text = QCheckBox("Show Text")
        self.show_text.clicked.connect(self.on_show_text)
        self.tag_text_field = QLineEdit(self)
        self.tag_text_field.setPlaceholderText("Comma separated image tags")
        self.layout.addWidget(self.open_files)
        self.layout.addWidget(self.next_file)
        self.layout.addWidget(self.previous_file)
        self.layout.addWidget(self.save_mask)
        self.layout.addWidget(self.save_bounding_boxes)
        self.layout.addWidget(self.delete_existing_annotation)
        self.layout.addWidget(self.show_text)
        self.layout.addWidget(self.show_bounding_boxes)
        self.layout.addWidget(self.tag_text_field)
        self.layout.addWidget(self.checkpoint_path_label)
        self.layout.addWidget(self.checkpoint_path)
        self.checkpoint_path.returnPressed.connect(self.on_checkpoint_path_changed)
        self.checkpoint_path.editingFinished.connect(self.on_checkpoint_path_changed)
        self.layout.addWidget(self.precompute_button)
        self.layout.addWidget(self.show_image)
        self.layout.addWidget(self.show_visualization)
        self.files = FilesHolder()
        self.original_image = np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8)

    def on_delete_existing_annotation(self):
        path = os.path.split(self.actual_file)[0]
        basename = os.path.splitext(os.path.basename(self.actual_file))[0]
        mask_path = os.path.join(path, basename + self.MASK_EXTENSION)
        labels_path = os.path.join(path, basename + self.LABELS_EXTENSION)
        if os.path.exists(mask_path):
            os.remove(mask_path)
        if os.path.exists(labels_path):
            os.remove(labels_path)

    def is_show_text(self):
        return self.show_text.isChecked()

    def on_show_text(self):
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_next_file(self):
        file = self.files.get_next()
        self._load_image(file)

    def on_previous_file(self):
        file = self.files.get_previous()
        self._load_image(file)

    def _load_image(self, file: str):
        mask = file.split(".")[0] + self.MASK_EXTENSION
        labels = file.split(".")[0] + self.LABELS_EXTENSION
        bounding_boxes = file.split(".")[0] + self.BOUNDING_BOXES_EXTENSION
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        self.actual_shape = image.shape[:2][::-1]
        self.actual_file = file
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if image.dtype in [np.float32, np.float64, np.uint16]:
            image = (image / np.amax(image) * 255).astype("uint8")
        # image = np.expand_dims(image[:, :, 2], axis=-1).repeat(3, axis=-1)
        image = cv2.resize(image,
                           (int(self.parent().config.window_size[0]), self.parent().config.window_size[1]))
        self.parent().annotator.clear()
        self.parent().image_label.clear()
        self.original_image = image.copy()
        self.parent().set_image(image)
        if os.path.exists(mask) and os.path.exists(labels):
            self._load_annotation(mask, labels)
            self.parent().info_label.setText(f"Loaded annotation from saved files! Image: {file}")
            self.parent().update(self.parent().annotator.merge_image_visualization())
        elif os.path.exists(bounding_boxes):
            self._load_bounding_boxes(bounding_boxes)
            self.parent().info_label.setText(f"Loaded bounding boxes from saved files! Image: {file}")
            self.parent().update(self.parent().annotator.merge_image_visualization())
        else:
            self.parent().info_label.setText(f"No annotation found! Image: {file}")
            self.tag_text_field.setText("")

    def _load_annotation(self, mask, labels):
        mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
        mask = cv2.resize(mask, (self.config.window_size[0], self.config.window_size[1]),
                          interpolation=cv2.INTER_NEAREST)
        with open(labels, "r") as fp:
            labels: dict[str, str] = json.load(fp)
        masks = []
        new_labels = []
        if "instances" in labels:
            instance_labels = labels["instances"]
        else:
            instance_labels = labels
        if "tags" in labels:
            self.tag_text_field.setText(",".join(labels["tags"]))
        else:
            self.tag_text_field.setText("")
        for str_index, class_ in instance_labels.items():
            single_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)
            single_mask[mask == int(str_index)] = 255
            masks.append(single_mask)
            new_labels.append(class_)
        self.parent().annotator.masks = MasksAnnotation.from_masks(masks, new_labels)

    def _load_bounding_boxes(self, bounding_boxes):
        with open(bounding_boxes, "r") as f:
            bounding_boxes: list[dict[str, float | str]] = json.load(f)
        for bounding_box in bounding_boxes:
            self.parent().annotator.bounding_boxes.append(BoundingBox(**bounding_box))

    def on_show_image(self):
        self.parent().set_image(self.original_image, clear_annotations=False)

    def on_show_visualization(self):
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_precompute(self):
        pass

    def on_save_mask(self):
        path = os.path.split(self.actual_file)[0]
        tags = self.tag_text_field.text().split(",")
        tags = [tag.strip() for tag in tags]
        basename = os.path.splitext(os.path.basename(self.actual_file))[0]
        mask_path = os.path.join(path, basename + self.MASK_EXTENSION)
        labels_path = os.path.join(path, basename + self.LABELS_EXTENSION)
        masks = self.parent().get_mask()
        labels = {"instances": self.parent().get_labels(), "tags": tags}
        with open(labels_path, "w") as f:
            json.dump(labels, f, indent=4)
        masks = cv2.resize(masks, self.actual_shape, interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(mask_path, masks)

    def on_checkpoint_path_changed(self):
        self.parent().sam = self.parent().init_sam()

    def on_open_files(self):
        files, _ = QFileDialog.getOpenFileNames(self, "Open Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)")
        random.shuffle(files)
        self.files.add_files(files)
        self.on_next_file()

    def on_save_bounding_boxes(self):
        path = os.path.split(self.actual_file)[0]
        basename = pathlib.Path(self.actual_file).stem
        bounding_boxes_path = os.path.join(path, basename + self.BOUNDING_BOXES_EXTENSION)
        bounding_boxes = self.parent().get_bounding_boxes()
        bounding_boxes_dict = [bounding_box.to_dict() for bounding_box in bounding_boxes]
        with open(bounding_boxes_path, "w") as f:
            json.dump(bounding_boxes_dict, f, indent=4)

    def is_show_bounding_boxes(self):
        return self.show_bounding_boxes.isChecked()

    def on_show_bounding_boxes(self):
        self.parent().update(self.parent().annotator.merge_image_visualization())
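The sidecar files written by on_save_mask are an instance-indexed label image (`<name>_mask.png`) and a JSON file. A small illustration of the JSON shape, with hypothetical class names:

import json

# <name>_labels.json: pixel value in the mask image -> class label,
# plus free-form image tags from the tag text field.
example = {
    "instances": {"1": "cat", "2": "dog"},  # hypothetical classes
    "tags": ["indoor", "blurry"],
}
print(json.dumps(example, indent=4))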

View File

View File

@@ -0,0 +1,55 @@
import dataclasses

import numpy as np


@dataclasses.dataclass
class BoundingBox:
    x_min: float
    y_min: float
    x_max: float
    y_max: float
    label: str
    mask_uid: str = ""

    def to_dict(self):
        return {
            "x_min": self.x_min,
            "y_min": self.y_min,
            "x_max": self.x_max,
            "y_max": self.y_max,
            "label": self.label,
            "mask_uid": self.mask_uid
        }

    @property
    def center(self):
        return np.array([(self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2])

    def distance_to(self, point: np.ndarray):
        return np.linalg.norm(self.center - point)

    def contains(self, point: np.ndarray):
        return self.x_min <= point[0] <= self.x_max and self.y_min <= point[1] <= self.y_max


def get_mask_bounding_box(mask, label: str):
    where = np.where(mask)
    x_min = np.min(where[1])
    y_min = np.min(where[0])
    x_max = np.max(where[1])
    y_max = np.max(where[0])
    return BoundingBox(
        x_min / mask.shape[1],
        y_min / mask.shape[0],
        x_max / mask.shape[1],
        y_max / mask.shape[0],
        label
    )


def get_bounding_boxes(masks, labels):
    bounding_boxes = []
    for mask, label in zip(masks, labels):
        bounding_box = get_mask_bounding_box(mask, label)
        bounding_boxes.append(bounding_box)
    return bounding_boxes
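A quick usage sketch for the helpers above; note that get_mask_bounding_box normalizes coordinates by the mask's width and height:

import numpy as np

toy_mask = np.zeros((10, 10), dtype=np.uint8)
toy_mask[2:5, 3:8] = 1
box = get_mask_bounding_box(toy_mask, "cell")
print(box.to_dict())                       # x_min=0.3, y_min=0.2, x_max=0.7, y_max=0.4
print(box.contains(np.array([0.4, 0.3])))  # True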

View File

@@ -0,0 +1,26 @@
import glob
import os

import cv2
import numpy as np
import torch
import rich
from PIL import Image
import safetensors

from segment_anything import sam_model_registry
from segment_anything_ui.modeling.storable_sam import StorableSam
from segment_anything_ui.config import Config

config = Config()
sam = sam_model_registry[config.get_sam_model_name()](checkpoint=config.default_weights)
allowed_extensions = [".jpg", ".png", ".tif", ".tiff"]


def load_images_from_folder(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        allowed_extensions = [".jpg", ".png"]  # shadows the module-level list (drops .tif/.tiff)
        if any(filename.endswith(ext) for ext in allowed_extensions):
            img = Image.open(os.path.join(folder_path, filename))
            images.append(img)  # was missing: loaded images were silently discarded
    return images
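The script stops short of the actual precomputation. A hedged sketch of what it could look like: SamPredictor.set_image and its cached .features are part of the public segment-anything API, while the output file layout below is an assumption:

from segment_anything import SamPredictor
from safetensors.torch import save_file

def precompute_embeddings(folder_path, out_folder):
    predictor = SamPredictor(sam)
    for filename in os.listdir(folder_path):
        if not any(filename.endswith(ext) for ext in allowed_extensions):
            continue
        image = np.array(Image.open(os.path.join(folder_path, filename)).convert("RGB"))
        predictor.set_image(image)             # runs the image encoder once
        embedding = predictor.features.cpu()   # cached encoder output
        save_file({"embedding": embedding},
                  os.path.join(out_folder, filename + ".safetensors"))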

View File

@@ -0,0 +1,53 @@
import dataclasses

import cv2
import numpy as np
from PySide6.QtCore import QPoint
from PySide6.QtGui import QPolygon


@dataclasses.dataclass
class BoundingBox:
    xstart: float | int
    ystart: float | int
    xend: float | int = -1.
    yend: float | int = -1.

    def to_numpy(self):
        return np.array([self.xstart, self.ystart, self.xend, self.yend])

    def scale(self, sx, sy):
        return BoundingBox(
            xstart=self.xstart * sx,
            ystart=self.ystart * sy,
            xend=self.xend * sx,
            yend=self.yend * sy
        )

    def to_int(self):
        return BoundingBox(
            xstart=int(self.xstart),
            ystart=int(self.ystart),
            xend=int(self.xend),
            yend=int(self.yend)
        )


@dataclasses.dataclass
class Polygon:
    points: list = dataclasses.field(default_factory=list)

    def to_numpy(self):
        return np.array(self.points).reshape(-1, 2)

    def to_mask(self, num_rows, num_cols):
        mask = np.zeros((num_rows, num_cols))
        # cv2.fillPoly only accepts int32 point arrays; the raw to_numpy()
        # output is int64 or float depending on the stored points
        mask = cv2.fillPoly(mask, pts=[self.to_numpy().astype(np.int32), ], color=255)
        return mask

    def is_plotable(self):
        return len(self.points) > 3

    def to_qpolygon(self):
        return QPolygon([
            QPoint(x, y) for x, y in self.points
        ])
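A short usage sketch for Polygon:

# Rasterize a click polygon into a binary mask and convert it for Qt painting.
poly = Polygon(points=[(2, 2), (8, 2), (8, 8), (2, 8)])
mask = poly.to_mask(num_rows=10, num_cols=10)  # 255 inside the square, 0 elsewhere
qpoly = poly.to_qpolygon()                     # QPolygon for overlay drawing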

View File