initial_tune
0  segment_anything_ui/__init__.py  Normal file
218  segment_anything_ui/annotation_layout.py  Normal file
@@ -0,0 +1,218 @@
import enum
import json
import os

import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox

from segment_anything_ui.draw_label import PaintType
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation


class MergeState(enum.Enum):
    PICKING = enum.auto()
    MERGING = enum.auto()


class AnnotationLayout(QWidget):

    def __init__(self, parent, config) -> None:
        super().__init__(parent)
        self.config = config
        self.zoom_flag = False
        self.merge_state = MergeState.PICKING
        self.layout = QVBoxLayout(self)
        labels = self._load_labels(config)
        self.layout.setAlignment(Qt.AlignTop)
        self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]")
        self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]")
        self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]")
        self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]")
        self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]")
        self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]")
        self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]")
        self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]")
        self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]")
        self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]")
        self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]")
        self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]")
        self.label_picker = QListWidget()
        self.label_picker.addItems(labels)
        self.label_picker.setCurrentRow(0)
        # the handler moves the mask to the background, so label the button accordingly
        self.move_current_mask_background = QPushButton("Move Current Mask to Background")
        self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
        self.remove_hidden_masks_label = QLabel("Remove masks with visible area ratio below")
        self.remove_hidden_masks_line = QLineEdit("0.5")
        self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key)
        self.add_point.setShortcut(config.key_mapping.ADD_POINT.key)
        self.add_box.setShortcut(config.key_mapping.ADD_BOX.key)
        self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key)
        self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key)
        self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key)
        self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key)
        self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key)
        self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key)
        self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key)

        self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())
        for w in [
            self.add_point,
            self.add_box,
            self.annotate_all,
            self.pick_mask,
            self.pick_bounding_box,
            self.merge_masks,
            self.move_current_mask_background,
            self.cancel_annotation,
            self.delete_mask,
            self.partial_annotation,
            self.save_annotation,
            self.manual_polygon,
            self.label_picker,
            self.zoom_rectangle,
            self.annotation_settings,
            self.remove_hidden_masks,
            self.remove_hidden_masks_label,
            self.remove_hidden_masks_line
        ]:
            self.layout.addWidget(w)
        self.add_point.clicked.connect(self.on_add_point)
        self.add_box.clicked.connect(self.on_add_box)
        self.annotate_all.clicked.connect(self.on_annotate_all)
        self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
        self.save_annotation.clicked.connect(self.on_save_annotation)
        self.pick_mask.clicked.connect(self.on_pick_mask)
        self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
        self.manual_polygon.clicked.connect(self.on_manual_polygon)
        self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
        self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
        self.merge_masks.clicked.connect(self.on_merge_masks)
        self.partial_annotation.clicked.connect(self.on_partial_annotation)
        self.delete_mask.clicked.connect(self.on_delete_mask)
        self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)

    def on_delete_mask(self):
        if self.parent().image_label.paint_type == PaintType.MASK_PICKER:
            self.parent().info_label.setText("Deleting mask!")
            mask_uid = self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
            self.parent().annotator.bounding_boxes.remove(mask_uid)
            self.parent().annotator.masks.mask_id = -1
            self.parent().annotator.last_mask = None
            self.parent().update(self.parent().annotator.merge_image_visualization())
        elif self.parent().image_label.paint_type == PaintType.BOX_PICKER:
            self.parent().info_label.setText("Deleting bounding box!")
            mask_uid = self.parent().annotator.bounding_boxes.remove_by_id(
                self.parent().annotator.bounding_boxes.bounding_box_id)
            if mask_uid is not None:
                self.parent().annotator.masks.pop_by_uuid(mask_uid)
            self.parent().annotator.bounding_boxes.bounding_box_id = -1
            self.parent().annotator.last_mask = None
            self.parent().annotator.masks.mask_id = -1
            self.parent().update(self.parent().annotator.merge_image_visualization())
        else:
            QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")

    def on_partial_annotation(self):
        self.parent().info_label.setText("Partial annotation!")
        self.parent().annotator.pick_partial_mask()
        self.parent().image_label.clear()

    @staticmethod
    def _load_labels(config):
        if not os.path.exists(config.label_file):
            return ["default"]
        with open(config.label_file, "r") as f:
            labels = json.load(f)
        MasksAnnotation.DEFAULT_LABEL = list(labels.keys())[0] if len(labels) > 0 else "default"
        # QListWidget.addItems expects a list of strings, so hand back the label names
        return list(labels.keys())

    def on_merge_masks(self):
        self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
        if self.merge_state == MergeState.PICKING:
            self.parent().info_label.setText("Pick a mask to merge with!")
            self.merge_state = MergeState.MERGING
            self.parent().annotator.merged_mask = self.parent().annotator.last_mask.copy()
        elif self.merge_state == MergeState.MERGING:
            self.parent().info_label.setText("Merging masks!")
            self.parent().annotator.merge_masks()
            self.merge_state = MergeState.PICKING

    def on_move_current_mask_background_fn(self):
        self.parent().info_label.setText("Moving current mask to background!")
        self.parent().annotator.move_current_mask_to_background()
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_remove_hidden_masks(self):
        self.parent().info_label.setText("Removing hidden masks!")
        annotations = self.parent().annotator.masks
        argmax_mask = self.parent().annotator.make_instance_mask()
        limit_ratio = float(self.remove_hidden_masks_line.text())
        new_masks = []
        new_labels = []
        for idx, (mask, label) in enumerate(annotations):
            num_pixels = np.sum(mask > 0)
            num_visible = np.sum(argmax_mask == (idx + 1))
            ratio = num_visible / num_pixels
            if ratio > limit_ratio:
                new_masks.append(mask)
                new_labels.append(label)

        print("Removed", len(annotations) - len(new_masks), "masks.")
        self.parent().annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_pick_mask(self):
        self.parent().info_label.setText("Pick a mask to do modifications!")
        self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)

    def on_pick_bounding_box(self):
        self.parent().info_label.setText("Pick a bounding box to do modifications!")
        self.parent().image_label.change_paint_type(PaintType.BOX_PICKER)

    def on_manual_polygon(self):
        # Sets emphasis on the button
        self.manual_polygon.setProperty("active", True)
        self.parent().image_label.change_paint_type(PaintType.POLYGON)

    def on_add_point(self):
        self.parent().info_label.setText("Adding point annotation!")
        self.parent().image_label.change_paint_type(PaintType.POINT)

    def on_add_box(self):
        self.parent().info_label.setText("Adding box annotation!")
        self.parent().image_label.change_paint_type(PaintType.BOX)

    def on_zoom_rectangle(self):
        if self.zoom_flag:
            self.parent().info_label.setText("Zooming rectangle OFF!")
            self.parent().image_label.change_paint_type(PaintType.POINT)
            self.parent().annotator.zoomed_bounding_box = None
            self.parent().annotator.make_embedding()
            self.parent().update(self.parent().annotator.merge_image_visualization())
            self.zoom_flag = False
        else:
            self.parent().info_label.setText("Pick Mask to zoom!")
            self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
            self.parent().image_label.change_paint_type(PaintType.ZOOM_PICKER)
            self.zoom_flag = True

    def on_annotate_all(self):
        self.parent().info_label.setText("Annotating all. This may take some time.")
        self.parent().annotator.predict_all(self.annotation_settings.get_values())
        self.parent().update(self.parent().annotator.merge_image_visualization())
        self.parent().info_label.setText("Annotate all finished.")

    def on_cancel_annotation(self):
        self.parent().image_label.clear()
        self.parent().annotator.clear_last_masks()
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_save_annotation(self):
        if self.parent().image_label.paint_type == PaintType.POLYGON:
            self.parent().annotator.last_mask = self.parent().image_label.polygon.to_mask(
                self.config.window_size[0], self.config.window_size[1])
        self.parent().annotator.save_mask(label=self.label_picker.currentItem().text())
        self.parent().update(self.parent().annotator.merge_image_visualization())
        self.parent().image_label.clear()
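As a standalone illustration of the filter in on_remove_hidden_masks above: a mask survives only if the fraction of its pixels that stay visible in the stacking argmax exceeds the threshold. A minimal, self-contained sketch with synthetic masks (names are illustrative; note that np.argmax keeps the first maximum, so the earlier mask wins overlapping pixels):

import numpy as np

# two overlapping 4x4 binary masks; `a` covers columns 0-2, `b` covers columns 2-3
a = np.zeros((4, 4), np.uint8); a[:, :3] = 255
b = np.zeros((4, 4), np.uint8); b[:, 2:] = 255
masks = [a, b]

# index 0 is the background (constant 1), index i + 1 is mask i, as in make_instance_mask
background = np.ones_like(masks[0])
argmax = np.argmax(np.stack([background] + masks), axis=0)

limit_ratio = 0.5
kept = [m for i, m in enumerate(masks)
        if (argmax == i + 1).sum() / (m > 0).sum() > limit_ratio]
print(len(kept))  # 1: `a` stays fully visible; `b` is only half visible and is dropped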
445  segment_anything_ui/annotator.py  Normal file
@@ -0,0 +1,445 @@
import dataclasses
from typing import Callable
import uuid

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit
from segment_anything import SamPredictor
from segment_anything.build_sam import Sam
from segment_anything_ui.model_builder import get_predictor, get_mask_generator
try:
    from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam
except (ImportError, ModuleNotFoundError):
    class EfficientViTSamPredictor:
        pass

    class EfficientViTSam:
        pass

from skimage.measure import regionprops
import torch

from segment_anything_ui.utils.shapes import BoundingBox
from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box


def get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    try:
        return plt.cm.get_cmap(name, n)
    except AttributeError:
        # newer matplotlib releases removed plt.cm.get_cmap
        return plt.get_cmap(name, n)


def crop_image(
        image,
        box: BoundingBox | None = None,
        image_shape: tuple[int, int] | None = None
):
    if image_shape is None:
        image_shape = image.shape[:2][::-1]
    if box is None:
        return cv2.resize(image, image_shape)
    if len(image.shape) == 2:
        return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend], image_shape)
    return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend, :], image_shape)


def insert_image(image, box: BoundingBox | None = None):
    new_image = np.zeros_like(image)
    if box is None:
        new_image = image
    else:
        if len(image.shape) == 2:
            new_image[box.ystart:box.yend, box.xstart:box.xend] = cv2.resize(
                image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
        else:
            new_image[box.ystart:box.yend, box.xstart:box.xend, :] = cv2.resize(
                image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
    return new_image


@dataclasses.dataclass()
class AutomaticMaskGeneratorSettings:
    points_per_side: int = 32
    pred_iou_thresh: float = 0.88
    stability_score_thresh: float = 0.95
    stability_score_offset: float = 1.0
    box_nms_thresh: float = 0.7
    crop_n_layers: int = 0
    crop_nms_thresh: float = 0.7


class LabelValueParam(QWidget):
    def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None):
        super().__init__(parent)
        self.layout = QVBoxLayout(self)
        self.layout.setSpacing(0)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter)
        self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter)
        self.layout.addWidget(self.label)
        self.layout.addWidget(self.value)
        self.converter = value_type_converter

    def get_value(self):
        return self.converter(self.value.text())


class CustomForm(QWidget):

    def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None:
        super().__init__(parent)
        self.layout = QVBoxLayout(self)
        self.layout.setSpacing(0)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.widgets = []
        for field in dataclasses.fields(automatic_mask_generator_settings):
            widget = LabelValueParam(field.name, str(field.default), field.type)
            self.widgets.append(widget)
            self.layout.addWidget(widget)

    def get_values(self):
        return AutomaticMaskGeneratorSettings(**{widget.label.text(): widget.get_value() for widget in self.widgets})


class BoundingBoxAnnotation:
    def __init__(self) -> None:
        self.bounding_boxes: list[BoundingBox] = []
        self.bounding_box_id: int = -1

    def append(self, bounding_box: BoundingBox):
        self.bounding_boxes.append(bounding_box)

    def find_closest_bounding_box(self, point: np.ndarray):
        closest_bounding_box = None
        closest_bounding_box_id = -1
        min_distance = float('inf')
        for idx, bounding_box in enumerate(self.bounding_boxes):
            distance = bounding_box.distance_to(point)
            if distance < min_distance and bounding_box.contains(point):
                min_distance = distance
                closest_bounding_box = bounding_box
                closest_bounding_box_id = idx
        self.bounding_box_id = closest_bounding_box_id
        return closest_bounding_box, closest_bounding_box_id

    def get_bounding_box(self, bounding_box_id: int):
        return self.bounding_boxes[bounding_box_id]

    def get_current_bounding_box(self):
        return self.bounding_boxes[self.bounding_box_id]

    def set_current_bounding_box(self, bounding_box: BoundingBox):
        self.bounding_boxes[self.bounding_box_id] = bounding_box

    def remove(self, mask_uid: str):
        bounding_box_id = next((idx for idx, bounding_box in enumerate(self.bounding_boxes) if bounding_box.mask_uid == mask_uid), None)
        if bounding_box_id is None:
            return
        bounding_box = self.bounding_boxes.pop(bounding_box_id)
        if self.bounding_box_id >= bounding_box_id:
            self.bounding_box_id -= 1
        return bounding_box

    def remove_by_id(self, bounding_box_id: int):
        mask_uid = self.bounding_boxes[bounding_box_id].mask_uid
        self.remove(mask_uid)
        return mask_uid

    def __len__(self):
        return len(self.bounding_boxes)


class MasksAnnotation:
    DEFAULT_LABEL = "default"

    def __init__(self) -> None:
        self.masks = []
        self.label_map = {}
        self.masks_uids: list[str] = []
        self.mask_id: int = -1

    def add_mask(self, mask, label: str | None = None):
        self.masks.append(mask)
        self.masks_uids.append(str(uuid.uuid4()))
        self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label
        return self.masks_uids[-1]

    def add_label(self, mask_id: int, label: str):
        self.label_map[mask_id] = label

    def get_mask(self, mask_id: int):
        return self.masks[mask_id]

    def get_label(self, mask_id: int):
        return self.label_map[mask_id]

    def get_current_mask(self):
        return self.masks[self.mask_id]

    def set_current_mask(self, mask, label: str = None):
        self.masks[self.mask_id] = mask
        # label_map keys are 1-based (see add_mask), so shift the 0-based mask index
        self.label_map[self.mask_id + 1] = self.DEFAULT_LABEL if label is None else label

    def __getitem__(self, mask_id: int):
        return self.get_mask(mask_id)

    def __setitem__(self, mask_id: int, value):
        self.masks[mask_id] = value

    def __len__(self):
        return len(self.masks)

    def __iter__(self):
        return iter(zip(self.masks, self.label_map.values()))

    def __next__(self):
        if self.mask_id >= len(self.masks):
            raise StopIteration
        return self.masks[self.mask_id]

    def append(self, mask, label: str | None = None):
        return self.add_mask(mask, label)

    def pop_by_uuid(self, mask_uid: str):
        mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None)
        if mask_id is None:
            return
        return self.pop(mask_id)

    def pop(self, mask_id: int = -1):
        if mask_id < 0:
            # normalize negative indices so the 1-based label_map key below is correct
            mask_id += len(self.masks)
        _ = self.masks.pop(mask_id)
        mask_uid = self.masks_uids.pop(mask_id)
        self.label_map.pop(mask_id + 1)
        new_label_map = {}
        for index, value in enumerate(self.label_map.values()):
            new_label_map[index + 1] = value
        self.label_map = new_label_map
        return mask_uid

    @classmethod
    def from_masks(cls, masks, labels: list[str] | None = None):
        annotation = cls()
        if labels is None:
            labels = [None] * len(masks)
        for mask, label in zip(masks, labels):
            annotation.append(mask, label)
        return annotation


@dataclasses.dataclass()
class Annotator:
    sam: Sam | EfficientViTSam | None = None
    embedding: torch.Tensor | None = None
    image: np.ndarray | None = None
    masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation)
    bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation)
    predictor: SamPredictor | EfficientViTSamPredictor | None = None
    visualization: np.ndarray | None = None
    last_mask: np.ndarray | None = None
    partial_mask: np.ndarray | None = None
    merged_mask: np.ndarray | None = None
    parent: QWidget | None = None
    cmap: plt.cm = None
    original_image: np.ndarray | None = None
    zoomed_bounding_box: BoundingBox | None = None

    def __post_init__(self):
        self.MAX_MASKS = 10
        self.cmap = get_cmap(self.MAX_MASKS)

    def set_image(self, image: np.ndarray):
        self.image = image
        return self

    def make_embedding(self):
        if self.sam is None:
            return
        self.predictor = get_predictor(self.sam)
        self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box))

    def predict_all(self, settings: AutomaticMaskGeneratorSettings):
        generator = get_mask_generator(
            sam=self.sam,
            **dataclasses.asdict(settings)
        )
        masks = generator.generate(self.image)
        masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks]
        label = self.parent.annotation_layout.label_picker.currentItem().text()
        self.masks = MasksAnnotation.from_masks(masks, [label, ] * len(masks))
        self.cmap = get_cmap(len(self.masks))

    def make_prediction(self, annotation: dict):
        masks, scores, logits = self.predictor.predict(
            point_coords=annotation["points"],
            point_labels=annotation["labels"],
            box=annotation["bounding_boxes"],
            multimask_output=False
        )
        mask = masks[0]
        self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255

    def pick_partial_mask(self):
        if self.partial_mask is None:
            self.partial_mask = self.last_mask.copy()
        else:
            self.partial_mask = np.maximum(self.last_mask, self.partial_mask)
        self.last_mask = None

    def move_current_mask_to_background(self):
        self.masks.set_current_mask(self.masks.get_current_mask() * 0.5)

    def merge_masks(self):
        new_mask = np.bitwise_or(self.last_mask, self.merged_mask)
        self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text())
        self.merged_mask = None

    def visualize_last_mask(self, label: str | None = None):
        last_mask = np.zeros_like(self.image)
        last_mask[:, :, 1] = self.last_mask
        if self.partial_mask is not None:
            last_mask[:, :, 0] = self.partial_mask
        if self.merged_mask is not None:
            last_mask[:, :, 2] = self.merged_mask
        if label is not None:
            props = regionprops(self.last_mask)[0]
            cv2.putText(
                last_mask,
                label,
                (int(props.centroid[1]), int(props.centroid[0])),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                [255, 255, 255],
                2
            )
            if self.is_show_bounding_boxes:
                last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label)
                cv2.rectangle(
                    last_mask,
                    (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
                    (int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])),
                    (0, 255, 0),
                    2
                )
                cv2.putText(
                    last_mask,
                    label,
                    (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    [255, 255, 255],
                    2
                )
        visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(),
                                        0.8, last_mask, 0.5, 0)
        self.parent.update(crop_image(visualization, self.zoomed_bounding_box))

    def _visualize_mask(self) -> tuple:
        mask_argmax = self.make_instance_mask()
        visualization = np.zeros_like(self.image)
        border = np.zeros(self.image.shape[:2], dtype=np.uint8)
        for i in range(1, np.amax(mask_argmax) + 1):
            color = self.cmap(i)
            single_mask = np.zeros_like(mask_argmax)
            single_mask[mask_argmax == i] = 1
            visualization[mask_argmax == i, :] = np.array(color[:3]) * 255
            border += single_mask - cv2.erode(
                single_mask, np.ones((3, 3), np.uint8), iterations=1)
            label = self.masks.get_label(i)
            single_mask_center = np.mean(np.where(single_mask == 1), axis=1)
            if np.isnan(single_mask_center).any():
                continue
            if self.parent.settings.is_show_text():
                cv2.putText(
                    visualization,
                    label,
                    (int(single_mask_center[1]), int(single_mask_center[0])),
                    cv2.FONT_HERSHEY_PLAIN,
                    0.5,
                    [255, 255, 255],
                    1
                )
        if self.is_show_bounding_boxes:
            bounding_boxes = self.get_bounding_boxes()
            for idx, bounding_box in enumerate(bounding_boxes):
                cv2.rectangle(
                    visualization,
                    (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
                    (int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])),
                    (0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0),
                    2
                )
                cv2.putText(
                    visualization,
                    bounding_box.label,
                    (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    [255, 255, 255],
                    2
                )
        border = (border == 0).astype(np.uint8)
        return visualization, border

    def has_annotations(self):
        return len(self.masks) > 0

    def make_instance_mask(self):
        background = np.zeros_like(self.masks[0]) + 1
        mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8)
        return mask_argmax

    def get_bounding_boxes(self):
        return get_bounding_boxes(self.masks.masks, self.masks.label_map.values())

    def merge_image_visualization(self):
        image = self.image.copy()
        if not len(self.masks):
            return crop_image(image, self.zoomed_bounding_box)
        visualization, border = self._visualize_mask()
        self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis]
        return crop_image(self.visualization, self.zoomed_bounding_box)

    def remove_last_mask(self):
        # pop() removes the newest mask and returns its uid, which keys the bounding box list
        mask_uid = self.masks.pop()
        self.bounding_boxes.remove(mask_uid)

    def make_labels(self):
        return self.masks.label_map

    def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL):
        if self.partial_mask is not None:
            last_mask = self.partial_mask
            self.partial_mask = None
        else:
            last_mask = self.last_mask
        mask_uid = self.masks.add_mask(last_mask, label=label)
        corresponding_bounding_box = get_mask_bounding_box(last_mask, label)
        corresponding_bounding_box.mask_uid = mask_uid
        self.bounding_boxes.append(corresponding_bounding_box)
        if len(self.masks) >= self.MAX_MASKS:
            self.MAX_MASKS += 10
            self.cmap = get_cmap(self.MAX_MASKS)

    @property
    def is_show_bounding_boxes(self):
        return self.parent.settings.is_show_bounding_boxes()

    def clear_last_masks(self):
        self.last_mask = None
        self.partial_mask = None
        self.visualization = None

    def clear(self):
        self.last_mask = None
        self.visualization = None
        self.masks = MasksAnnotation()
        self.bounding_boxes = BoundingBoxAnnotation()
        self.partial_mask = None
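The zoom workflow depends on crop_image and insert_image acting as rough inverses: predictions are made on the enlarged crop and pasted back into full-image coordinates. A self-contained sketch of that round trip (the Box dataclass stands in for utils.shapes.BoundingBox, which this commit does not include):

import cv2
import numpy as np
from dataclasses import dataclass

@dataclass
class Box:  # stand-in for segment_anything_ui.utils.shapes.BoundingBox
    xstart: int
    ystart: int
    xend: int
    yend: int

full = np.zeros((100, 100), np.uint8)
box = Box(20, 30, 60, 70)

# crop_image: cut the box out and blow it up to the working resolution
crop = cv2.resize(full[box.ystart:box.yend, box.xstart:box.xend], (100, 100))
print(crop.shape)  # (100, 100)

# insert_image: shrink a full-resolution mask back into the box region
mask = np.full((100, 100), 255, np.uint8)
restored = np.zeros_like(mask)
restored[box.ystart:box.yend, box.xstart:box.xend] = cv2.resize(
    mask, (box.xend - box.xstart, box.yend - box.ystart))
assert restored[box.ystart:box.yend, box.xstart:box.xend].all()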
143  segment_anything_ui/config.py  Normal file
@@ -0,0 +1,143 @@
import dataclasses
import os
from typing import Literal

from PySide6.QtCore import Qt
import requests
try:
    from tqdm import tqdm
    import wget
except ImportError:
    print("Tqdm and wget not found. Install with pip install tqdm wget")
    tqdm = None
    wget = None


@dataclasses.dataclass(frozen=True)
class Keymap:
    key: Qt.Key | str
    name: str


class ProgressBar:
    def __init__(self):
        self.progress_bar = None

    def __call__(self, current_bytes, total_bytes, width):
        current_mb = round(current_bytes / 1024 ** 2, 1)
        total_mb = round(total_bytes / 1024 ** 2, 1)
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=total_mb, desc="MB")
        delta_mb = current_mb - self.progress_bar.n
        self.progress_bar.update(delta_mb)


@dataclasses.dataclass
class KeyBindings:
    ADD_POINT: Keymap = Keymap(Qt.Key.Key_W, "W")
    ADD_BOX: Keymap = Keymap(Qt.Key.Key_Q, "Q")
    ANNOTATE_ALL: Keymap = Keymap(Qt.Key.Key_Return, "Enter")
    MANUAL_POLYGON: Keymap = Keymap(Qt.Key.Key_R, "R")
    CANCEL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_C, "C")
    SAVE_ANNOTATION: Keymap = Keymap(Qt.Key.Key_S, "S")
    PICK_MASK: Keymap = Keymap(Qt.Key.Key_X, "X")
    PICK_BOUNDING_BOX: Keymap = Keymap(Qt.Key.Key_B, "B")
    MERGE_MASK: Keymap = Keymap(Qt.Key.Key_Z, "Z")
    DELETE_MASK: Keymap = Keymap(Qt.Key.Key_V, "V")
    PARTIAL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_D, "D")
    SAVE_BOUNDING_BOXES: Keymap = Keymap("Ctrl+B", "Ctrl+B")
    NEXT_FILE: Keymap = Keymap(Qt.Key.Key_F, "F")
    PREVIOUS_FILE: Keymap = Keymap(Qt.Key.Key_G, "G")
    SAVE_MASK: Keymap = Keymap("Ctrl+S", "Ctrl+S")
    PRECOMPUTE: Keymap = Keymap(Qt.Key.Key_P, "P")
    ZOOM_RECTANGLE: Keymap = Keymap(Qt.Key.Key_E, "E")


@dataclasses.dataclass
class Config:
    default_weights: Literal[
        "sam_vit_b_01ec64.pth",
        "sam_vit_h_4b8939.pth",
        "sam_vit_l_0b3195.pth",
        "xl0.pt",
        "xl1.pt",
        "sam_hq_vit_b.pth",
        "sam_hq_vit_l.pth",
        "sam_hq_vit_h.pth",
        "sam_hq_vit_tiny.pth",
        "sam2.1_hiera_t.pth",
        "sam2.1_hiera_l.pth",
        "sam2.1_hiera_b+.pth",
        "sam2.1_hiera_s.pth",
    ] = "sam_vit_h_4b8939.pth"
    download_weights_if_not_available: bool = True
    label_file: str = "labels.json"
    window_size: tuple[int, int] | int = (1920, 1080)
    key_mapping: KeyBindings = dataclasses.field(default_factory=KeyBindings)
    weights_paths: dict[str, str] = dataclasses.field(default_factory=lambda: {
        "l2": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "xl0": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt",
        "xl1": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt",
        "hq_vit_b": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
        "hq_vit_l": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
        "hq_vit_h": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
        "hq_vit_tiny": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
        "sam2.1_hiera_t": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
        "sam2.1_hiera_s": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
        "sam2.1_hiera_b+": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
        "sam2.1_hiera_l": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
    })

    def __post_init__(self):
        if isinstance(self.window_size, int):
            self.window_size = (self.window_size, self.window_size)
        if self.download_weights_if_not_available:
            self.download_weights()

    def get_sam_model_name(self):
        if "l2" in self.default_weights:
            return "l2"
        if "sam_vit_b" in self.default_weights:
            return "vit_b"
        if "sam_vit_h" in self.default_weights:
            return "vit_h"
        if "sam_vit_l" in self.default_weights:
            return "vit_l"
        if "xl0" in self.default_weights:
            return "xl0"
        if "xl1" in self.default_weights:
            return "xl1"
        if "hq_vit_b" in self.default_weights:
            return "hq_vit_b"
        if "hq_vit_l" in self.default_weights:
            return "hq_vit_l"
        if "hq_vit_h" in self.default_weights:
            return "hq_vit_h"
        if "hq_vit_tiny" in self.default_weights:
            return "hq_vit_tiny"
        if "sam2.1_hiera_t" in self.default_weights:
            return "sam2.1_hiera_t"
        if "sam2.1_hiera_l" in self.default_weights:
            return "sam2.1_hiera_l"
        if "sam2.1_hiera_b+" in self.default_weights:
            return "sam2.1_hiera_b+"
        if "sam2.1_hiera_s" in self.default_weights:
            return "sam2.1_hiera_s"
        raise ValueError("Unknown model name")

    def download_weights(self):
        if not os.path.exists(self.default_weights):
            try:
                print(f"Downloading weights for model {self.get_sam_model_name()}")
                wget.download(self.weights_paths[self.get_sam_model_name()], self.default_weights, bar=ProgressBar())
                print(f"Downloaded weights to {self.default_weights}")
            except Exception as e:
                print(f"Error downloading weights: {e}. Trying with requests.")
                model_name = self.get_sam_model_name()
                print(f"Downloading weights for model {model_name}")
                file = requests.get(self.weights_paths[model_name])
                with open(self.default_weights, "wb") as f:
                    f.write(file.content)
                print(f"Downloaded weights to {self.default_weights}")
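A minimal usage sketch for Config (values are illustrative; pointing default_weights at an existing file with downloading disabled skips the download step in __post_init__):

from segment_anything_ui.config import Config

# assumes sam_vit_b_01ec64.pth already sits in the working directory
config = Config(default_weights="sam_vit_b_01ec64.pth",
                download_weights_if_not_available=False,
                window_size=1024)  # an int is expanded to (1024, 1024) in __post_init__
print(config.get_sam_model_name())       # -> "vit_b"
print(config.key_mapping.ADD_POINT.key)  # -> Qt.Key.Key_W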
268  segment_anything_ui/draw_label.py  Normal file
@@ -0,0 +1,268 @@
import copy
from enum import Enum

import cv2
import numpy as np
import PySide6
from PySide6 import QtCore, QtWidgets
from PySide6.QtGui import QPainter, QPen

from segment_anything_ui.config import Config
from segment_anything_ui.utils.shapes import BoundingBox, Polygon


class PaintType(Enum):
    POINT = 0
    BOX = 1
    MASK = 2
    POLYGON = 3
    MASK_PICKER = 4
    ZOOM_PICKER = 5
    BOX_PICKER = 6


class MaskIdPicker:

    def __init__(self, length) -> None:
        self.counter = 0
        self.length = length

    def increment(self):
        self.counter = (self.counter + 1) % self.length

    def pick(self, ids):
        print("Length of ids: ", len(ids), " counter: ", self.counter, " ids: ", ids)
        if len(ids) <= self.counter:
            self.counter = 0
        return_id = ids[self.counter]
        self.increment()
        return return_id


class DrawLabel(QtWidgets.QLabel):

    def __init__(self, parent=None):
        super().__init__(parent)
        self.positive_points = []
        self.negative_points = []
        self.bounding_box = None
        self.partial_box = BoundingBox(0, 0)
        self._paint_type = PaintType.POINT
        self.polygon = Polygon()
        self.mask_enum: MaskIdPicker = MaskIdPicker(3)
        self.config = Config()
        self.setFocusPolicy(QtCore.Qt.StrongFocus)
        self._zoom_center = (0, 0)
        self._zoom_factor = 1.0
        self._zoom_bounding_box: BoundingBox | None = None

    def paintEvent(self, paint_event):
        painter = QPainter(self)
        painter.drawPixmap(self.rect(), self.pixmap())
        pen_positive = self._get_pen(QtCore.Qt.green, 5)
        pen_negative = self._get_pen(QtCore.Qt.red, 5)
        pen_partial = self._get_pen(QtCore.Qt.yellow, 1)
        pen_box = self._get_pen(QtCore.Qt.green, 1)
        painter.setRenderHint(QPainter.Antialiasing, False)

        painter.setPen(pen_box)
        if self.bounding_box is not None and self.bounding_box.xend != -1 and self.bounding_box.yend != -1:
            painter.drawRect(
                self.bounding_box.xstart,
                self.bounding_box.ystart,
                self.bounding_box.xend - self.bounding_box.xstart,
                self.bounding_box.yend - self.bounding_box.ystart
            )

        painter.setPen(pen_partial)
        painter.drawRect(self.partial_box.xstart, self.partial_box.ystart,
                         self.partial_box.xend - self.partial_box.xstart,
                         self.partial_box.yend - self.partial_box.ystart)

        painter.setPen(pen_positive)
        for pos in self.positive_points:
            painter.drawPoint(pos)

        painter.setPen(pen_negative)
        painter.setRenderHint(QPainter.Antialiasing, False)
        for pos in self.negative_points:
            painter.drawPoint(pos)

        if self.polygon.is_plotable():
            painter.setPen(pen_box)
            painter.setRenderHint(QPainter.Antialiasing, True)
            painter.drawPolygon(self.polygon.to_qpolygon())
        # self.update()

    def _get_pen(self, color=QtCore.Qt.red, width=5):
        pen = QPen()
        pen.setWidth(width)
        pen.setColor(color)
        return pen

    @property
    def paint_type(self):
        return self._paint_type

    def change_paint_type(self, paint_type: PaintType):
        print(f"Changing paint type to {paint_type}")
        self._paint_type = paint_type

    def mouseMoveEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
        if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]:
            self.partial_box = copy.deepcopy(self.bounding_box)
            self.partial_box.xend = ev.pos().x()
            self.partial_box.yend = ev.pos().y()
            self.update()

        if self._paint_type == PaintType.POINT:
            point = ev.pos()
            if ev.buttons() == QtCore.Qt.LeftButton:
                self._move_update(None, point)
            elif ev.buttons() == QtCore.Qt.RightButton:
                self._move_update(point, None)
            else:
                pass
            self.update()

    def _move_update(self, temporary_point_negative, temporary_point_positive):
        annotations = self.get_annotations(temporary_point_positive, temporary_point_negative)
        self.parent().annotator.make_prediction(annotations)
        self.parent().annotator.visualize_last_mask()

    def mouseReleaseEvent(self, cursor_event):
        if self._paint_type == PaintType.POINT:
            if cursor_event.button() == QtCore.Qt.LeftButton:
                self.positive_points.append(cursor_event.pos())
                print(self.size())
            elif cursor_event.button() == QtCore.Qt.RightButton:
                self.negative_points.append(cursor_event.pos())
                # self.chosen_points.append(self.mapFromGlobal(QtGui.QCursor.pos()))
        elif self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]:
            if cursor_event.button() == QtCore.Qt.LeftButton:
                self.bounding_box.xend = cursor_event.pos().x()
                self.bounding_box.yend = cursor_event.pos().y()
                self.partial_box = BoundingBox(-1, -1, -1, -1)

        if self._paint_type not in (
                PaintType.MASK_PICKER, PaintType.ZOOM_PICKER,
                PaintType.POLYGON, PaintType.BOX_PICKER):
            self.parent().annotator.make_prediction(self.get_annotations())
            self.parent().annotator.visualize_last_mask()

        if self._paint_type == PaintType.ZOOM_PICKER:
            self.parent().annotator.zoomed_bounding_box = self.bounding_box.scale(*self._get_scale()).to_int()
            self.bounding_box = None
            self.parent().annotator.make_embedding()
            self.parent().update(self.parent().annotator.merge_image_visualization())
            self._paint_type = PaintType.POINT

        self.update()

    def mousePressEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
        if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER] and ev.button() == QtCore.Qt.LeftButton:
            self.bounding_box = BoundingBox(xstart=ev.pos().x(), ystart=ev.pos().y())

        if self._paint_type == PaintType.POLYGON and ev.button() == QtCore.Qt.LeftButton:
            self.polygon.points.append([ev.pos().x(), ev.pos().y()])

        if self._paint_type == PaintType.MASK_PICKER and ev.button() == QtCore.Qt.LeftButton:
            size = self.size()
            point = [
                int(ev.pos().x() / size.width() * self.config.window_size[0]),
                int(ev.pos().y() / size.height() * self.config.window_size[1])]
            masks = np.array(self.parent().annotator.masks.masks)
            mask_ids = np.where(masks[:, point[1], point[0]])[0]
            print("Picking mask at point: {}".format(point))
            if not len(mask_ids):
                print("No mask found")
                mask_id = -1
                local_mask = np.zeros((masks.shape[1], masks.shape[2]))
                label = None
            else:
                mask_id = self.mask_enum.pick(mask_ids)
                local_mask = self.parent().annotator.masks.get_mask(mask_id)
                label = self.parent().annotator.masks.get_label(mask_id + 1)
            self.parent().annotator.masks.mask_id = mask_id
            self.parent().annotator.last_mask = local_mask
            self.parent().annotator.visualize_last_mask(label)

        if self._paint_type == PaintType.BOX_PICKER and ev.button() == QtCore.Qt.LeftButton:
            size = self.size()
            point = [
                float(ev.pos().x() / size.width()),
                float(ev.pos().y() / size.height())]
            bounding_box, bounding_box_id = self.parent().annotator.bounding_boxes.find_closest_bounding_box(point)
            if bounding_box is None:
                print("No bounding box found")
            else:
                self.parent().annotator.bounding_boxes.bounding_box_id = bounding_box_id
                print(f"Bounding box: {bounding_box}")
                print(f"Bounding box id: {bounding_box_id}")
            self.parent().update(self.parent().annotator.merge_image_visualization())

        if self._paint_type == PaintType.POINT:
            point = ev.pos()
            if ev.button() == QtCore.Qt.LeftButton:
                self._move_update(None, point)
            if ev.button() == QtCore.Qt.RightButton:
                self._move_update(point, None)
        self.update()

    def zoom_to_rectangle(self, xstart, ystart, xend, yend):
        picked_image = self.parent().annotator.image[ystart:yend, xstart:xend, :]
        self.parent().annotator.image = cv2.resize(picked_image, (self.config.window_size[0], self.config.window_size[1]))
        self.update()

    def keyPressEvent(self, ev: PySide6.QtGui.QKeyEvent) -> None:
        print(ev.key())
        if self._paint_type == PaintType.MASK_PICKER and ev.key() == QtCore.Qt.Key.Key_D and len(self.parent().annotator.masks):
            print("Deleting mask")
            self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
            self.parent().annotator.masks.mask_id = -1
            self.parent().annotator.last_mask = None
            self.parent().update(self.parent().annotator.merge_image_visualization())

    def _get_scale(self):
        return self.config.window_size[0] / self.size().width(), self.config.window_size[1] / self.size().height()

    def get_annotations(
            self,
            temporary_point_positive: PySide6.QtCore.QPoint | None = None,
            temporary_point_negative: PySide6.QtCore.QPoint | None = None
    ):
        sx, sy = self._get_scale()
        positive_points = [(
            p.x() * sx,
            p.y() * sy) for p in self.positive_points]
        negative_points = [(
            p.x() * sx,
            p.y() * sy) for p in self.negative_points]

        if temporary_point_positive:
            positive_points += [(temporary_point_positive.x() * sx, temporary_point_positive.y() * sy)]
        if temporary_point_negative:
            negative_points += [(temporary_point_negative.x() * sx, temporary_point_negative.y() * sy)]

        positive_points = np.array(positive_points).reshape(-1, 2)
        negative_points = np.array(negative_points).reshape(-1, 2)
        labels = np.array([1, ] * len(positive_points) + [0, ] * len(negative_points))
        print(f"Positive points: {positive_points}")
        print(f"Negative points: {negative_points}")
        print(f"Labels: {labels}")
        return {
            "points": np.concatenate([positive_points, negative_points], axis=0),
            "labels": labels,
            "bounding_boxes": self.bounding_box.scale(sx, sy).to_numpy() if self.bounding_box else None
        }

    def clear(self):
        self.positive_points = []
        self.negative_points = []
        self.bounding_box = None
        self.partial_box = BoundingBox(0, 0, 0, 0)
        self.polygon = Polygon()
        self.update()
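get_annotations converts widget pixels into model-input pixels using the ratios from _get_scale; the same arithmetic as a standalone sketch (all numbers illustrative):

window_size = (1920, 1080)   # model-input resolution from Config
label_size = (960, 540)      # current on-screen size of the QLabel

sx = window_size[0] / label_size[0]   # 2.0
sy = window_size[1] / label_size[1]   # 2.0

click = (480, 270)                    # widget coordinates of a mouse click
point = (click[0] * sx, click[1] * sy)
print(point)                          # (960.0, 540.0) in model-input coordinates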
14  segment_anything_ui/image_pixmap.py  Normal file
@@ -0,0 +1,14 @@
from PySide6.QtGui import QImage, QPixmap, QPainter, QPen
from PySide6.QtCore import Qt


class ImagePixmap(QPixmap):
    def __init__(self):
        super().__init__()

    def set_image(self, image):
        # normalize to uint8 regardless of whether a uint8 or float image comes in
        if image.dtype == "uint8":
            image = image.astype("float32") / 255.0
        image = (image * 255).astype("uint8")
        # pass the row stride explicitly so widths that are not 4-byte aligned render correctly
        image = QImage(image.data, image.shape[1], image.shape[0], image.shape[1] * 3, QImage.Format_RGB888)
        self.convertFromImage(image)
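A minimal usage sketch for ImagePixmap (a QApplication must exist before any pixmap is created):

import sys
import numpy as np
from PySide6.QtWidgets import QApplication
from segment_anything_ui.image_pixmap import ImagePixmap

app = QApplication(sys.argv)
rgb = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
pixmap = ImagePixmap()
pixmap.set_image(rgb)  # accepts uint8 or float arrays; both end up as RGB888
print(pixmap.width(), pixmap.height())  # 640 480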
81  segment_anything_ui/main_window.py  Normal file
@@ -0,0 +1,81 @@
import logging
import sys

import cv2
import numpy as np
import torch
from PySide6.QtWidgets import (QApplication, QGridLayout, QLabel,
                               QMessageBox, QWidget)
from PySide6.QtCore import Qt

from segment_anything_ui.annotator import Annotator
from segment_anything_ui.annotation_layout import AnnotationLayout
from segment_anything_ui.config import Config
from segment_anything_ui.draw_label import DrawLabel
from segment_anything_ui.image_pixmap import ImagePixmap
from segment_anything_ui.model_builder import build_model
from segment_anything_ui.settings_layout import SettingsLayout


class SegmentAnythingUI(QWidget):

    def __init__(self, config) -> None:
        super().__init__()
        self.config: Config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.setWindowTitle("Segment Anything UI")
        self.setWindowState(Qt.WindowState.WindowMaximized)
        # self.setGeometry(100, 100, 800, 600)
        self.layout = QGridLayout(self)
        self.image_label = DrawLabel(self)
        self.settings = SettingsLayout(self, config=self.config)
        self.info_label = QLabel("Information about running process.")
        self.sam = self.init_sam()
        self.annotator = Annotator(sam=self.sam, parent=self)
        self.annotation_layout = AnnotationLayout(self, config=self.config)
        self.layout.addWidget(self.annotation_layout, 0, 0, 1, 1, Qt.AlignCenter)
        self.layout.addWidget(self.image_label, 0, 1, 1, 1, Qt.AlignCenter)
        self.layout.addWidget(self.settings, 0, 3, 1, 1, Qt.AlignCenter)
        self.layout.addWidget(self.info_label, 1, 1, Qt.AlignBottom)

        self.set_image(np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8))
        self.show()

    def set_image(self, image: np.ndarray, clear_annotations: bool = True):
        self.annotator.set_image(image).make_embedding()
        if clear_annotations:
            self.annotator.clear()
        self.update(image)

    def update(self, image: np.ndarray):
        image = cv2.resize(image, self.config.window_size)
        pixmap = ImagePixmap()
        pixmap.set_image(image)
        print("Updating image")
        self.image_label.setPixmap(pixmap)

    def init_sam(self):
        try:
            checkpoint_path = str(self.settings.checkpoint_path.text())
            sam = build_model(self.config.get_sam_model_name(), checkpoint_path, self.device)
        except Exception as e:
            logging.getLogger().exception(f"Error loading model: {e}")
            QMessageBox.critical(self, "Error", "Could not load model")
            return None
        return sam

    def get_mask(self):
        return self.annotator.make_instance_mask()

    def get_labels(self):
        return self.annotator.make_labels()

    def get_bounding_boxes(self):
        return self.annotator.get_bounding_boxes()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = SegmentAnythingUI(Config())
    sys.exit(app.exec())
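Beyond the __main__ block above, the window can also be launched with a non-default configuration; a sketch, assuming the checkpoint file exists locally or weight downloading is enabled:

import sys
from PySide6.QtWidgets import QApplication
from segment_anything_ui.config import Config
from segment_anything_ui.main_window import SegmentAnythingUI

app = QApplication(sys.argv)
# smaller ViT-B backbone instead of the default ViT-H
ui = SegmentAnythingUI(Config(default_weights="sam_vit_b_01ec64.pth"))
sys.exit(app.exec())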
129  segment_anything_ui/model_builder.py  Normal file
@@ -0,0 +1,129 @@
import os
from PySide6.QtWidgets import QMessageBox

try:
    from efficientvit.sam_model_zoo import create_efficientvit_sam_model, EfficientViTSam
    from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, EfficientViTSamAutomaticMaskGenerator
    IS_EFFICIENT_VIT_AVAILABLE = True
except (ModuleNotFoundError, ImportError) as e:
    import logging
    logging.warning("EfficientViT is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
    IS_EFFICIENT_VIT_AVAILABLE = False

try:
    from segment_anything_hq import sam_model_registry as sam_hq_model_registry
    from segment_anything_hq import SamPredictor as SamPredictorHQ
    from segment_anything_hq import automatic_mask_generator as automatic_mask_generator_hq
    from segment_anything_hq.build_sam import Sam as SamHQ
    IS_SAM_HQ_AVAILABLE = True
    _SAM_HQ_MODEL_REGISTRY = {
        "hq_vit_b": "vit_b",
        "hq_vit_l": "vit_l",
        "hq_vit_h": "vit_h",
        "hq_vit_tiny": "vit_tiny",
    }
except (ModuleNotFoundError, ImportError) as e:
    import logging
    logging.warning("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
    IS_SAM_HQ_AVAILABLE = False
    _SAM_HQ_MODEL_REGISTRY = {}

try:
    from sam2.build_sam import build_sam2
    from sam2.modeling.sam2_base import SAM2Base
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    IS_SAM2_AVAILABLE = True
    from hydra.core.global_hydra import GlobalHydra
    from hydra import initialize

    # Reset Hydra's global configuration
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    _SAM2_MODEL_REGISTRY = {
        "sam2.1_hiera_t": "sam2.1_hiera_t.yaml",
        "sam2.1_hiera_l": "sam2.1_hiera_l.yaml",
        "sam2.1_hiera_b+": "sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_s": "sam2.1_hiera_s.yaml",
    }

except (ModuleNotFoundError, ImportError) as e:
    import logging
    logging.warning("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
    IS_SAM2_AVAILABLE = False
    _SAM2_MODEL_REGISTRY = {}

from segment_anything import sam_model_registry
from segment_anything import SamPredictor, automatic_mask_generator
from segment_anything.build_sam import Sam


def build_model(model_name: str, checkpoint_path: str, device: str):
    match model_name:
        case "xl0" | "xl1":
            if not IS_EFFICIENT_VIT_AVAILABLE:
                raise ValueError("EfficientViTSam is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
            efficientvit_sam = create_efficientvit_sam_model(
                name=model_name, weight_url=checkpoint_path,
            )
            return efficientvit_sam.to(device).eval()

        case "vit_b" | "vit_l" | "vit_h":
            sam = sam_model_registry[model_name](
                checkpoint=checkpoint_path)
            sam.eval()
            return sam.to(device)

        case "hq_vit_b" | "hq_vit_l" | "hq_vit_h" | "hq_vit_tiny":
            if not IS_SAM_HQ_AVAILABLE:
                QMessageBox.critical(None, "Segment Anything HQ is not available", "Please install the package from http://github.com/SysCV/sam-hq .")
                raise ValueError("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
            sam = sam_hq_model_registry[_SAM_HQ_MODEL_REGISTRY[model_name]](
                checkpoint=checkpoint_path)
            sam.eval()
            return sam.to(device)
        case "sam2.1_hiera_t" | "sam2.1_hiera_l" | "sam2.1_hiera_b+" | "sam2.1_hiera_s":
            if not IS_SAM2_AVAILABLE:
                QMessageBox.critical(None, "SAM2 is not available", "Please install the package from https://github.com/facebookresearch/sam2 .")
                raise ValueError("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")

            with initialize(version_base=None, config_path="sam2_configs"):
                sam = build_sam2(_SAM2_MODEL_REGISTRY[model_name], checkpoint_path, device=device)
            sam.eval()
            return sam
        case _:
            raise ValueError(f"Model {model_name} not supported")


def get_predictor(sam):
    if isinstance(sam, Sam):
        return SamPredictor(sam)
    elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
        return EfficientViTSamPredictor(sam)
    elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
        return SamPredictorHQ(sam)
    elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        return SAM2ImagePredictor(sam)
    else:
        raise ValueError("Model is not a supported SAM variant (Sam, SamHQ, EfficientViTSam, or SAM2Base)")


def get_mask_generator(sam, **kwargs):
    if isinstance(sam, Sam):
        return automatic_mask_generator.SamAutomaticMaskGenerator(model=sam, **kwargs)
    elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
        return automatic_mask_generator_hq.SamAutomaticMaskGeneratorHQ(model=sam, **kwargs)
    elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
        return EfficientViTSamAutomaticMaskGenerator(model=sam, **kwargs)
    elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        # note: the generator settings kwargs are not forwarded to SAM2 here
        return SAM2AutomaticMaskGenerator(model=sam)
    else:
        raise ValueError("Model is not a supported SAM variant (Sam, SamHQ, EfficientViTSam, or SAM2Base)")
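The three helpers compose outside the UI as well; a sketch assuming the base segment-anything package is installed, a vit_b checkpoint sits next to the script, and example.jpg is any RGB image (both file names are placeholders):

import cv2
import torch
from segment_anything_ui.model_builder import build_model, get_predictor, get_mask_generator

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = build_model("vit_b", "sam_vit_b_01ec64.pth", device)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor = get_predictor(sam)     # dispatches on the model class
predictor.set_image(image)

generator = get_mask_generator(sam, points_per_side=16)
masks = generator.generate(image)  # list of dicts, each with a "segmentation" array
print(len(masks))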
0  segment_anything_ui/modeling/__init__.py  Normal file
45  segment_anything_ui/modeling/efficientvit/sam_model_zoo.py  Normal file
@@ -0,0 +1,45 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from segment_anything_ui.modeling.efficientvit.models.efficientvit import (
    EfficientViTSam,
    efficientvit_sam_l0,
    efficientvit_sam_l1,
    efficientvit_sam_l2,
)
from segment_anything_ui.modeling.efficientvit.models.nn.norm import set_norm_eps
from segment_anything_ui.modeling.efficientvit.models.utils import load_state_dict_from_file

__all__ = ["create_sam_model"]


REGISTERED_SAM_MODEL: dict[str, str] = {
    "l0": "assets/checkpoints/sam/l0.pt",
    "l1": "assets/checkpoints/sam/l1.pt",
    "l2": "assets/checkpoints/sam/l2.pt",
}


def create_sam_model(name: str, pretrained=True, weight_url: str | None = None, **kwargs) -> EfficientViTSam:
    model_dict = {
        "l0": efficientvit_sam_l0,
        "l1": efficientvit_sam_l1,
        "l2": efficientvit_sam_l2,
    }

    model_id = name.split("-")[0]
    if model_id not in model_dict:
        raise ValueError(f"Cannot find {name} in the model zoo. List of models: {list(model_dict.keys())}")
    model = model_dict[model_id](**kwargs)
    set_norm_eps(model, 1e-6)

    if pretrained:
        weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
        if weight_url is None:
            raise ValueError(f"Cannot find pretrained weights for {name}.")
        weight = load_state_dict_from_file(weight_url)
        model.load_state_dict(weight)
    return model
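
A hedged usage sketch; the "l0" checkpoint is assumed to have been downloaded to the path registered in REGISTERED_SAM_MODEL:

model = create_sam_model("l0", pretrained=True)  # loads assets/checkpoints/sam/l0.pt
model.eval()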
30
segment_anything_ui/modeling/storable_sam.py
Normal file
@@ -0,0 +1,30 @@
import torch.nn as nn
from safetensors import safe_open

from segment_anything.modeling import Sam


class ModifiedImageEncoder(nn.Module):

    def __init__(self, image_encoder, saved_path: str | None = None):
        super().__init__()
        self.image_encoder = image_encoder
        if saved_path is not None:
            # safe_open requires the target framework; "pt" loads torch tensors.
            self.embeddings = safe_open(saved_path, framework="pt")
        else:
            self.embeddings = None

    def forward(self, x):
        # Fall back to the wrapped encoder when no precomputed embeddings exist.
        return self.image_encoder(x) if self.embeddings is None else self.embeddings


class StorableSam:

    def __init__(self, sam):
        self.sam = sam
        self.image_encoder = sam.image_encoder

    def transform(self, saved_path):
        self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path)

    def precompute(self, image):
        return self.image_encoder(image)
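
A hedged sketch of the intended precompute-then-reuse flow; `sam` is a loaded Sam model and `batched_image` a preprocessed (1, 3, 1024, 1024) tensor, both assumed here rather than defined in this commit:

import torch
from safetensors.torch import save_file

storable = StorableSam(sam)
with torch.no_grad():
    embedding = storable.precompute(batched_image)  # run the heavy encoder once
save_file({"embedding": embedding}, "embeddings.safetensors")
storable.transform("embeddings.safetensors")        # later calls skip the encoder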
116
segment_anything_ui/sam2_configs/sam2.1_hiera_b+.yaml
Normal file
@@ -0,0 +1,116 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 112
      num_heads: 2
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [896, 448, 224, 112]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  no_obj_embed_spatial: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: true
  proj_tpos_enc_in_obj_ptrs: true
  use_signed_tpos_enc_to_obj_ptrs: true
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False
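
These YAML files are ordinary Hydra configs: each `_target_` names the class to instantiate with the keys below it. A hedged sketch of how one is resolved, mirroring the SAM2 branch of the factory earlier in this commit (the checkpoint filename is the released one and is assumed to be present locally):

from hydra import initialize
from sam2.build_sam import build_sam2

with initialize(version_base=None, config_path="sam2_configs"):
    sam2 = build_sam2("sam2.1_hiera_b+.yaml", "sam2.1_hiera_base_plus.pt", device="cuda")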
120
segment_anything_ui/sam2_configs/sam2.1_hiera_l.yaml
Normal file
@@ -0,0 +1,120 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 144
      num_heads: 2
      stages: [2, 6, 36, 4]
      global_att_blocks: [23, 33, 43]
      window_pos_embed_bkg_spatial_size: [7, 7]
      window_spec: [8, 4, 16, 8]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [1152, 576, 288, 144]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  no_obj_embed_spatial: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: true
  proj_tpos_enc_in_obj_ptrs: true
  use_signed_tpos_enc_to_obj_ptrs: true
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False
119
segment_anything_ui/sam2_configs/sam2.1_hiera_s.yaml
Normal file
@@ -0,0 +1,119 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 11, 2]
      global_att_blocks: [7, 10, 13]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  no_obj_embed_spatial: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: true
  proj_tpos_enc_in_obj_ptrs: true
  use_signed_tpos_enc_to_obj_ptrs: true
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False
121
segment_anything_ui/sam2_configs/sam2.1_hiera_t.yaml
Normal file
@@ -0,0 +1,121 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 7, 2]
      global_att_blocks: [5, 7, 9]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  # SAM decoder
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  no_obj_embed_spatial: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: true
  proj_tpos_enc_in_obj_ptrs: true
  use_signed_tpos_enc_to_obj_ptrs: true
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  # HieraT does not currently support compilation, should always be set to False
  compile_image_encoder: False
13
segment_anything_ui/saver.py
Normal file
@@ -0,0 +1,13 @@
import os
from typing import Any

import torch


class Saver:

    def __init__(self, path: str) -> None:
        self.path = path

    def __call__(self, basename, mask_annotation) -> Any:
        save_path = os.path.join(self.path, basename)
        # TODO: finish this
0
segment_anything_ui/segment_anything_control.py
Normal file
228
segment_anything_ui/settings_layout.py
Normal file
@@ -0,0 +1,228 @@
import json
import os
import pathlib
import random

import cv2
import numpy as np
from PySide6.QtWidgets import QPushButton, QWidget, QFileDialog, QVBoxLayout, QLineEdit, QLabel, QCheckBox, QMessageBox

from segment_anything_ui.annotator import MasksAnnotation
from segment_anything_ui.config import Config
from segment_anything_ui.utils.bounding_boxes import BoundingBox


class FilesHolder:
    def __init__(self):
        self.files = []
        self.position = 0

    def add_files(self, files):
        self.files.extend(files)

    def get_next(self):
        self.position += 1
        if self.position >= len(self.files):
            self.position = 0
        return self.files[self.position]

    def get_previous(self):
        self.position -= 1
        if self.position < 0:
            self.position = len(self.files) - 1
        return self.files[self.position]

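A hedged illustration of the wrap-around behaviour (note that get_next advances before reading, so the first call after loading lands on the second file):

holder = FilesHolder()
holder.add_files(["a.png", "b.png", "c.png"])
holder.get_next()      # -> "b.png": the position is advanced first
holder.get_previous()  # -> "a.png"
holder.get_previous()  # -> "c.png": wraps around to the end
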
class SettingsLayout(QWidget):
    MASK_EXTENSION = "_mask.png"
    LABELS_EXTENSION = "_labels.json"
    BOUNDING_BOXES_EXTENSION = "_bounding_boxes.json"

    def __init__(self, parent, config: Config) -> None:
        super().__init__(parent)
        self.config = config
        self.actual_file: str = ""
        self.actual_shape = self.config.window_size
        self.layout = QVBoxLayout(self)
        self.open_files = QPushButton("Open Files")
        self.open_files.clicked.connect(self.on_open_files)
        self.next_file = QPushButton(f"Next File [ {config.key_mapping.NEXT_FILE.name} ]")
        self.previous_file = QPushButton(f"Previous File [ {config.key_mapping.PREVIOUS_FILE.name} ]")
        self.previous_file.setShortcut(config.key_mapping.PREVIOUS_FILE.key)
        self.save_mask = QPushButton(f"Save Mask [ {config.key_mapping.SAVE_MASK.name} ]")
        self.save_mask.clicked.connect(self.on_save_mask)
        self.save_mask.setShortcut(config.key_mapping.SAVE_MASK.key)
        self.save_bounding_boxes = QPushButton(f"Save Bounding Boxes [ {config.key_mapping.SAVE_BOUNDING_BOXES.name} ]")
        self.save_bounding_boxes.clicked.connect(self.on_save_bounding_boxes)
        self.save_bounding_boxes.setShortcut(config.key_mapping.SAVE_BOUNDING_BOXES.key)
        self.next_file.clicked.connect(self.on_next_file)
        self.next_file.setShortcut(config.key_mapping.NEXT_FILE.key)
        self.previous_file.clicked.connect(self.on_previous_file)
        self.checkpoint_path_label = QLabel(self, text="Checkpoint Path")
        self.checkpoint_path = QLineEdit(self, text=self.parent().config.default_weights)
        self.precompute_button = QPushButton("Precompute all embeddings")
        self.precompute_button.clicked.connect(self.on_precompute)
        self.delete_existing_annotation = QPushButton("Delete existing annotation")
        self.delete_existing_annotation.clicked.connect(self.on_delete_existing_annotation)
        self.show_image = QPushButton("Show Image")
        self.show_visualization = QPushButton("Show Visualization")
        self.show_bounding_boxes = QCheckBox("Show Bounding Boxes")
        self.show_bounding_boxes.clicked.connect(self.on_show_bounding_boxes)
        self.show_image.clicked.connect(self.on_show_image)
        self.show_visualization.clicked.connect(self.on_show_visualization)
        self.show_text = QCheckBox("Show Text")
        self.show_text.clicked.connect(self.on_show_text)
        self.tag_text_field = QLineEdit(self)
        self.tag_text_field.setPlaceholderText("Comma separated image tags")
        self.layout.addWidget(self.open_files)
        self.layout.addWidget(self.next_file)
        self.layout.addWidget(self.previous_file)
        self.layout.addWidget(self.save_mask)
        self.layout.addWidget(self.save_bounding_boxes)
        self.layout.addWidget(self.delete_existing_annotation)
        self.layout.addWidget(self.show_text)
        self.layout.addWidget(self.show_bounding_boxes)
        self.layout.addWidget(self.tag_text_field)
        self.layout.addWidget(self.checkpoint_path_label)
        self.layout.addWidget(self.checkpoint_path)
        self.checkpoint_path.returnPressed.connect(self.on_checkpoint_path_changed)
        self.checkpoint_path.editingFinished.connect(self.on_checkpoint_path_changed)
        self.layout.addWidget(self.precompute_button)
        self.layout.addWidget(self.show_image)
        self.layout.addWidget(self.show_visualization)
        self.files = FilesHolder()
        self.original_image = np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8)

    def on_delete_existing_annotation(self):
        path = os.path.split(self.actual_file)[0]
        basename = os.path.splitext(os.path.basename(self.actual_file))[0]
        mask_path = os.path.join(path, basename + self.MASK_EXTENSION)
        labels_path = os.path.join(path, basename + self.LABELS_EXTENSION)
        if os.path.exists(mask_path):
            os.remove(mask_path)
        if os.path.exists(labels_path):
            os.remove(labels_path)

    def is_show_text(self):
        return self.show_text.isChecked()

    def on_show_text(self):
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_next_file(self):
        file = self.files.get_next()
        self._load_image(file)

    def on_previous_file(self):
        file = self.files.get_previous()
        self._load_image(file)

    def _load_image(self, file: str):
        # splitext is safer than file.split(".") for paths that contain extra dots.
        stem = os.path.splitext(file)[0]
        mask = stem + self.MASK_EXTENSION
        labels = stem + self.LABELS_EXTENSION
        bounding_boxes = stem + self.BOUNDING_BOXES_EXTENSION
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        self.actual_shape = image.shape[:2][::-1]
        self.actual_file = file
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if image.dtype in [np.float32, np.float64, np.uint16]:
            image = (image / np.amax(image) * 255).astype("uint8")
        image = cv2.resize(image, (int(self.parent().config.window_size[0]), int(self.parent().config.window_size[1])))
        self.parent().annotator.clear()
        self.parent().image_label.clear()
        self.original_image = image.copy()
        self.parent().set_image(image)
        if os.path.exists(mask) and os.path.exists(labels):
            self._load_annotation(mask, labels)
            self.parent().info_label.setText(f"Loaded annotation from saved files! Image: {file}")
            self.parent().update(self.parent().annotator.merge_image_visualization())
        elif os.path.exists(bounding_boxes):
            self._load_bounding_boxes(bounding_boxes)
            self.parent().info_label.setText(f"Loaded bounding boxes from saved files! Image: {file}")
            self.parent().update(self.parent().annotator.merge_image_visualization())
        else:
            self.parent().info_label.setText(f"No annotation found! Image: {file}")
            self.tag_text_field.setText("")

    def _load_annotation(self, mask, labels):
        mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
        mask = cv2.resize(mask, (self.config.window_size[0], self.config.window_size[1]),
                          interpolation=cv2.INTER_NEAREST)
        with open(labels, "r") as fp:
            # The file may hold either a flat {index: class} mapping or
            # {"instances": {...}, "tags": [...]}, so annotate as a plain dict.
            labels: dict = json.load(fp)
        masks = []
        new_labels = []
        if "instances" in labels:
            instance_labels = labels["instances"]
        else:
            instance_labels = labels

        if "tags" in labels:
            self.tag_text_field.setText(",".join(labels["tags"]))
        else:
            self.tag_text_field.setText("")
        for str_index, class_ in instance_labels.items():
            single_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)
            single_mask[mask == int(str_index)] = 255
            masks.append(single_mask)
            new_labels.append(class_)
        self.parent().annotator.masks = MasksAnnotation.from_masks(masks, new_labels)

    def _load_bounding_boxes(self, bounding_boxes):
        with open(bounding_boxes, "r") as f:
            bounding_boxes: list[dict[str, float | str]] = json.load(f)
        for bounding_box in bounding_boxes:
            self.parent().annotator.bounding_boxes.append(BoundingBox(**bounding_box))

    def on_show_image(self):
        self.parent().set_image(self.original_image, clear_annotations=False)

    def on_show_visualization(self):
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_precompute(self):
        pass

    def on_save_mask(self):
        path = os.path.split(self.actual_file)[0]
        tags = self.tag_text_field.text().split(",")
        tags = [tag.strip() for tag in tags]
        basename = os.path.splitext(os.path.basename(self.actual_file))[0]
        mask_path = os.path.join(path, basename + self.MASK_EXTENSION)
        labels_path = os.path.join(path, basename + self.LABELS_EXTENSION)
        masks = self.parent().get_mask()
        labels = {"instances": self.parent().get_labels(), "tags": tags}
        with open(labels_path, "w") as f:
            json.dump(labels, f, indent=4)
        masks = cv2.resize(masks, self.actual_shape, interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(mask_path, masks)

    def on_checkpoint_path_changed(self):
        self.parent().sam = self.parent().init_sam()

    def on_open_files(self):
        files, _ = QFileDialog.getOpenFileNames(self, "Open Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)")
        random.shuffle(files)
        self.files.add_files(files)
        self.on_next_file()

    def on_save_bounding_boxes(self):
        path = os.path.split(self.actual_file)[0]
        basename = pathlib.Path(self.actual_file).stem
        bounding_boxes_path = os.path.join(path, basename + self.BOUNDING_BOXES_EXTENSION)
        bounding_boxes = self.parent().get_bounding_boxes()
        bounding_boxes_dict = [bounding_box.to_dict() for bounding_box in bounding_boxes]
        with open(bounding_boxes_path, "w") as f:
            json.dump(bounding_boxes_dict, f, indent=4)

    def is_show_bounding_boxes(self):
        return self.show_bounding_boxes.isChecked()

    def on_show_bounding_boxes(self):
        self.parent().update(self.parent().annotator.merge_image_visualization())
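
For reference, the side-car files this layout reads and writes next to each image (a hedged sketch; the concrete class names and tags are illustrative, the filename suffixes come from the constants above):

example.png                   the annotated image
example_mask.png              instance mask; pixel value i marks instance i
example_labels.json           {"instances": {"1": "cat", "2": "dog"}, "tags": ["outdoor"]}
example_bounding_boxes.json   [{"x_min": 0.1, "y_min": 0.2, "x_max": 0.5, "y_max": 0.9, "label": "cat", "mask_uid": ""}]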
0
segment_anything_ui/utils/__init__.py
Normal file
55
segment_anything_ui/utils/bounding_boxes.py
Normal file
@@ -0,0 +1,55 @@
import dataclasses

import numpy as np


@dataclasses.dataclass
class BoundingBox:
    x_min: float
    y_min: float
    x_max: float
    y_max: float
    label: str
    mask_uid: str = ""

    def to_dict(self):
        return {
            "x_min": self.x_min,
            "y_min": self.y_min,
            "x_max": self.x_max,
            "y_max": self.y_max,
            "label": self.label,
            "mask_uid": self.mask_uid
        }

    @property
    def center(self):
        return np.array([(self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2])

    def distance_to(self, point: np.ndarray):
        return np.linalg.norm(self.center - point)

    def contains(self, point: np.ndarray):
        return self.x_min <= point[0] <= self.x_max and self.y_min <= point[1] <= self.y_max


def get_mask_bounding_box(mask, label: str):
    # Coordinates are normalized by the mask shape, so every field lies in [0, 1].
    where = np.where(mask)
    x_min = np.min(where[1])
    y_min = np.min(where[0])
    x_max = np.max(where[1])
    y_max = np.max(where[0])
    return BoundingBox(
        x_min / mask.shape[1],
        y_min / mask.shape[0],
        x_max / mask.shape[1],
        y_max / mask.shape[0],
        label
    )


def get_bounding_boxes(masks, labels):
    bounding_boxes = []
    for mask, label in zip(masks, labels):
        bounding_box = get_mask_bounding_box(mask, label)
        bounding_boxes.append(bounding_box)
    return bounding_boxes
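
A hedged example of the normalized-coordinate convention used above (a toy binary mask; all box fields end up in [0, 1]):

import numpy as np

mask = np.zeros((100, 100), dtype=bool)
mask[20:40, 10:50] = True                     # rows 20..39, cols 10..49
box = get_mask_bounding_box(mask, "cell")     # BoundingBox(0.1, 0.2, 0.49, 0.39, "cell")
print(box.center)                             # [0.295 0.295]
print(box.contains(np.array([0.3, 0.3])))     # True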
26
segment_anything_ui/utils/precompute_folder.py
Normal file
@@ -0,0 +1,26 @@
import glob
import os

import cv2
import numpy as np
import torch
import rich
from PIL import Image
import safetensors

from segment_anything import sam_model_registry

from segment_anything_ui.modeling.storable_sam import StorableSam
from segment_anything_ui.config import Config

config = Config()
sam = sam_model_registry[config.get_sam_model_name()](checkpoint=config.default_weights)
allowed_extensions = [".jpg", ".png", ".tif", ".tiff"]


def load_images_from_folder(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        if any(filename.endswith(ext) for ext in allowed_extensions):
            # Collect every matching image instead of discarding it.
            images.append(Image.open(os.path.join(folder_path, filename)))
    return images
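
The embedding loop itself is not in this commit yet; a hedged sketch of where it would go, reusing StorableSam from the import above (the "data/" folder is a placeholder):

if __name__ == "__main__":
    images = load_images_from_folder("data/")
    storable = StorableSam(sam)
    rich.print(f"Found {len(images)} images to precompute.")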
53
segment_anything_ui/utils/shapes.py
Normal file
@@ -0,0 +1,53 @@
import dataclasses

import cv2
import numpy as np
from PySide6.QtCore import QPoint
from PySide6.QtGui import QPolygon


@dataclasses.dataclass
class BoundingBox:
    xstart: float | int
    ystart: float | int
    xend: float | int = -1.
    yend: float | int = -1.

    def to_numpy(self):
        return np.array([self.xstart, self.ystart, self.xend, self.yend])

    def scale(self, sx, sy):
        return BoundingBox(
            xstart=self.xstart * sx,
            ystart=self.ystart * sy,
            xend=self.xend * sx,
            yend=self.yend * sy
        )

    def to_int(self):
        return BoundingBox(
            xstart=int(self.xstart),
            ystart=int(self.ystart),
            xend=int(self.xend),
            yend=int(self.yend)
        )


@dataclasses.dataclass
class Polygon:
    points: list = dataclasses.field(default_factory=list)

    def to_numpy(self):
        return np.array(self.points).reshape(-1, 2)

    def to_mask(self, num_rows, num_cols):
        mask = np.zeros((num_rows, num_cols), dtype=np.uint8)
        # cv2.fillPoly expects int32 point arrays.
        mask = cv2.fillPoly(mask, pts=[self.to_numpy().astype(np.int32)], color=255)
        return mask

    def is_plotable(self):
        return len(self.points) > 3

    def to_qpolygon(self):
        return QPolygon([
            QPoint(x, y) for x, y in self.points
        ])
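
A hedged example of rasterizing a manually drawn polygon with the class above:

poly = Polygon(points=[(10, 10), (90, 10), (90, 80), (10, 80)])
mask = poly.to_mask(num_rows=100, num_cols=100)  # 100x100 array, polygon filled with 255
assert poly.is_plotable()                        # needs more than 3 points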
0
segment_anything_ui/utils/tooltips.py
Normal file