import dataclasses
from typing import Callable
import uuid

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit
from segment_anything import SamPredictor
from segment_anything.build_sam import Sam
from segment_anything_ui.model_builder import (
    get_predictor, get_mask_generator, SamPredictor)
try:
    from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam
except (ImportError, ModuleNotFoundError):
    class EfficientViTSamPredictor:
        pass

    class EfficientViTSam:
        pass

from skimage.measure import regionprops
import torch

from segment_anything_ui.utils.shapes import BoundingBox
from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box


def get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    try:
        return plt.cm.get_cmap(name, n)
    except AttributeError:
        # plt.cm.get_cmap was removed in newer matplotlib releases.
        return plt.get_cmap(name, n)


def crop_image(
        image,
        box: BoundingBox | None = None,
        image_shape: tuple[int, int] | None = None
):
    if image_shape is None:
        image_shape = image.shape[:2][::-1]
    if box is None:
        return cv2.resize(image, image_shape)

    if len(image.shape) == 2:
        return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend], image_shape)
    return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend, :], image_shape)


def insert_image(image, box: BoundingBox | None = None):
    new_image = np.zeros_like(image)
    if box is None:
        new_image = image
    else:
        # Cast box coordinates once so both the slice and the resize target agree.
        xstart, xend = int(box.xstart), int(box.xend)
        ystart, yend = int(box.ystart), int(box.yend)
        resized = cv2.resize(image.astype(np.uint8), (xend - xstart, yend - ystart))
        if len(image.shape) == 2:
            new_image[ystart:yend, xstart:xend] = resized
        else:
            new_image[ystart:yend, xstart:xend, :] = resized
    return new_image


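# Usage sketch (not part of the original module): crop_image() and insert_image()
# are near-inverses used for the zoom workflow, as in Annotator.make_prediction()
# below. The `zoomed_box` name is illustrative; any BoundingBox (or None for the
# full frame) works.
#
#   zoomed_view = crop_image(image, zoomed_box)        # crop + resize back to full image size
#   full_size_mask = insert_image(mask, zoomed_box)    # paste a prediction back into the box

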
@dataclasses.dataclass()
class AutomaticMaskGeneratorSettings:
    points_per_side: int = 32
    pred_iou_thresh: float = 0.88
    stability_score_thresh: float = 0.95
    stability_score_offset: float = 1.0
    box_nms_thresh: float = 0.7
    crop_n_layers: int = 0
    crop_nms_thresh: float = 0.7


class LabelValueParam(QWidget):
    def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None):
        super().__init__(parent)
        self.layout = QVBoxLayout(self)
        self.layout.setSpacing(0)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter)
        self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter)
        self.layout.addWidget(self.label)
        self.layout.addWidget(self.value)
        self.converter = value_type_converter

    def get_value(self):
        return self.converter(self.value.text())


class CustomForm(QWidget):

    def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None:
        super().__init__(parent)
        self.layout = QVBoxLayout(self)
        self.layout.setSpacing(0)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.widgets = []

        for field in dataclasses.fields(automatic_mask_generator_settings):
            widget = LabelValueParam(field.name, str(field.default), field.type)
            self.widgets.append(widget)
            self.layout.addWidget(widget)

    def get_values(self):
        return AutomaticMaskGeneratorSettings(**{widget.label.text(): widget.get_value() for widget in self.widgets})


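# Usage sketch (illustrative, assumes a parent QWidget exists): CustomForm builds
# one LabelValueParam per AutomaticMaskGeneratorSettings field, and get_values()
# round-trips the edited text back through each field's type converter.
#
#   form = CustomForm(parent_widget, AutomaticMaskGeneratorSettings())
#   settings = form.get_values()   # AutomaticMaskGeneratorSettings with user-edited values

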
class BoundingBoxAnnotation:
    def __init__(self) -> None:
        self.bounding_boxes: list[BoundingBox] = []
        self.bounding_box_id: int = -1

    def append(self, bounding_box: BoundingBox):
        self.bounding_boxes.append(bounding_box)

    def find_closest_bounding_box(self, point: np.ndarray):
        closest_bounding_box = None
        closest_bounding_box_id = -1
        min_distance = float('inf')
        for idx, bounding_box in enumerate(self.bounding_boxes):
            distance = bounding_box.distance_to(point)
            if distance < min_distance and bounding_box.contains(point):
                min_distance = distance
                closest_bounding_box = bounding_box
                closest_bounding_box_id = idx
        self.bounding_box_id = closest_bounding_box_id
        return closest_bounding_box, closest_bounding_box_id

    def get_bounding_box(self, bounding_box_id: int):
        return self.bounding_boxes[bounding_box_id]

    def get_current_bounding_box(self):
        return self.bounding_boxes[self.bounding_box_id]

    def set_current_bounding_box(self, bounding_box: BoundingBox):
        self.bounding_boxes[self.bounding_box_id] = bounding_box

    def remove(self, mask_uid: str):
        bounding_box_id = next((idx for idx, bounding_box in enumerate(self.bounding_boxes) if bounding_box.mask_uid == mask_uid), None)
        if bounding_box_id is None:
            return
        bounding_box = self.bounding_boxes.pop(bounding_box_id)
        if self.bounding_box_id >= bounding_box_id:
            self.bounding_box_id -= 1
        return bounding_box

    def remove_by_id(self, bounding_box_id: int):
        mask_uid = self.bounding_boxes[bounding_box_id].mask_uid
        self.remove(mask_uid)
        return mask_uid

    def __len__(self):
        return len(self.bounding_boxes)


class MasksAnnotation:
    DEFAULT_LABEL = "default"

    def __init__(self) -> None:
        self.masks = []
        # label_map is keyed 1-based (key 0 is implicitly the background).
        self.label_map = {}
        self.masks_uids: list[str] = []
        self.mask_id: int = -1

    def add_mask(self, mask, label: str | None = None):
        self.masks.append(mask)
        self.masks_uids.append(str(uuid.uuid4()))
        self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label
        return self.masks_uids[-1]

    def add_label(self, mask_id: int, label: str):
        self.label_map[mask_id] = label

    def get_mask(self, mask_id: int):
        return self.masks[mask_id]

    def get_label(self, mask_id: int):
        return self.label_map[mask_id]

    def get_current_mask(self):
        return self.masks[self.mask_id]

    def set_current_mask(self, mask, label: str | None = None):
        self.masks[self.mask_id] = mask
        self.label_map[self.mask_id] = self.DEFAULT_LABEL if label is None else label

    def __getitem__(self, mask_id: int):
        return self.get_mask(mask_id)

    def __setitem__(self, mask_id: int, value):
        self.masks[mask_id] = value

    def __len__(self):
        return len(self.masks)

    def __iter__(self):
        return iter(zip(self.masks, self.label_map.values()))

    def __next__(self):
        if self.mask_id >= len(self.masks):
            raise StopIteration
        return self.masks[self.mask_id]

    def append(self, mask, label: str | None = None):
        return self.add_mask(mask, label)

    def pop_by_uuid(self, mask_uid: str):
        mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None)
        if mask_id is None:
            return
        return self.pop(mask_id)

    def pop(self, mask_id: int = -1):
        # Normalize negative indices so the matching 1-based label_map key is removed.
        if mask_id < 0:
            mask_id += len(self.masks)
        _ = self.masks.pop(mask_id)
        mask_uid = self.masks_uids.pop(mask_id)
        self.label_map.pop(mask_id + 1)
        new_label_map = {}
        for index, value in enumerate(self.label_map.values()):
            new_label_map[index + 1] = value
        self.label_map = new_label_map
        return mask_uid

    @classmethod
    def from_masks(cls, masks, labels: list[str] | None = None):
        annotation = cls()
        if labels is None:
            labels = [None] * len(masks)
        for mask, label in zip(masks, labels):
            annotation.append(mask, label)
        return annotation


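# Usage sketch (labels below are made up): MasksAnnotation stores masks 0-based
# while label_map is keyed 1-based (key 0 is the implicit background), which is
# what Annotator._visualize_mask() relies on when calling get_label(i) for i >= 1.
#
#   annotation = MasksAnnotation.from_masks([mask_a, mask_b], ["cell", "vessel"])
#   for mask, label in annotation:     # __iter__ yields (mask, label) pairs
#       ...
#   annotation.pop()                   # removes the last mask and re-keys label_map

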
@dataclasses.dataclass()
class Annotator:
    sam: Sam | EfficientViTSam | None = None
    embedding: torch.Tensor | None = None
    image: np.ndarray | None = None
    masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation)
    bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation)
    predictor: SamPredictor | EfficientViTSamPredictor | None = None
    visualization: np.ndarray | None = None
    last_mask: np.ndarray | None = None
    partial_mask: np.ndarray | None = None
    merged_mask: np.ndarray | None = None
    parent: QWidget | None = None
    cmap: Callable | None = None
    original_image: np.ndarray | None = None
    zoomed_bounding_box: BoundingBox | None = None

    def __post_init__(self):
        self.MAX_MASKS = 10
        self.cmap = get_cmap(self.MAX_MASKS)

    def set_image(self, image: np.ndarray):
        self.image = image
        return self

    def make_embedding(self):
        if self.sam is None:
            return
        self.predictor = get_predictor(self.sam)
        self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box))

    def predict_all(self, settings: AutomaticMaskGeneratorSettings):
        generator = get_mask_generator(
            sam=self.sam,
            **dataclasses.asdict(settings)
        )
        masks = generator.generate(self.image)
        masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks]
        label = self.parent.annotation_layout.label_picker.currentItem().text()
        self.masks = MasksAnnotation.from_masks(masks, [label] * len(masks))
        self.cmap = get_cmap(len(self.masks))

    def make_prediction(self, annotation: dict):
        masks, scores, logits = self.predictor.predict(
            point_coords=annotation["points"],
            point_labels=annotation["labels"],
            box=annotation["bounding_boxes"],
            multimask_output=False
        )
        mask = masks[0]
        self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255

    def pick_partial_mask(self):
        if self.partial_mask is None:
            self.partial_mask = self.last_mask.copy()
        else:
            self.partial_mask = np.maximum(self.last_mask, self.partial_mask)
        self.last_mask = None

    def move_current_mask_to_background(self):
        self.masks.set_current_mask(self.masks.get_current_mask() * 0.5)

    def merge_masks(self):
        new_mask = np.bitwise_or(self.last_mask, self.merged_mask)
        self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text())
        self.merged_mask = None

    def visualize_last_mask(self, label: str | None = None):
        last_mask = np.zeros_like(self.image)
        last_mask[:, :, 1] = self.last_mask
        if self.partial_mask is not None:
            last_mask[:, :, 0] = self.partial_mask
        if self.merged_mask is not None:
            last_mask[:, :, 2] = self.merged_mask
        if label is not None:
            props = regionprops(self.last_mask)[0]
            cv2.putText(
                last_mask,
                label,
                (int(props.centroid[1]), int(props.centroid[0])),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                [255, 255, 255],
                2
            )
        if self.is_show_bounding_boxes:
            last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label)
            cv2.rectangle(
                last_mask,
                (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
                (int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])),
                (0, 255, 0),
                2
            )
            cv2.putText(
                last_mask,
                label,
                (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                [255, 255, 255],
                2
            )
        visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(),
                                        0.8, last_mask, 0.5, 0)
        self.parent.update(crop_image(visualization, self.zoomed_bounding_box))

    def _visualize_mask(self) -> tuple:
        mask_argmax = self.make_instance_mask()
        visualization = np.zeros_like(self.image)
        border = np.zeros(self.image.shape[:2], dtype=np.uint8)
        for i in range(1, np.amax(mask_argmax) + 1):
            color = self.cmap(i)
            single_mask = np.zeros_like(mask_argmax)
            single_mask[mask_argmax == i] = 1
            visualization[mask_argmax == i, :] = np.array(color[:3]) * 255
            border += single_mask - cv2.erode(
                single_mask, np.ones((3, 3), np.uint8), iterations=1)
            label = self.masks.get_label(i)
            single_mask_center = np.mean(np.where(single_mask == 1), axis=1)
            if np.isnan(single_mask_center).any():
                continue
            if self.parent.settings.is_show_text():
                cv2.putText(
                    visualization,
                    label,
                    (int(single_mask_center[1]), int(single_mask_center[0])),
                    cv2.FONT_HERSHEY_PLAIN,
                    0.5,
                    [255, 255, 255],
                    1
                )
        if self.is_show_bounding_boxes:
            bounding_boxes = self.get_bounding_boxes()
            for idx, bounding_box in enumerate(bounding_boxes):
                cv2.rectangle(
                    visualization,
                    (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
                    (int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])),
                    (0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0),
                    2
                )
                cv2.putText(
                    visualization,
                    bounding_box.label,
                    (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    [255, 255, 255],
                    2
                )
        border = (border == 0).astype(np.uint8)
        return visualization, border

    def has_annotations(self):
        return len(self.masks) > 0

    def make_instance_mask(self):
        background = np.zeros_like(self.masks[0]) + 1
        mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8)
        return mask_argmax

    def get_bounding_boxes(self):
        return get_bounding_boxes(self.masks.masks, self.masks.label_map.values())

    def merge_image_visualization(self):
        image = self.image.copy()
        if not len(self.masks):
            return crop_image(image, self.zoomed_bounding_box)
        visualization, border = self._visualize_mask()
        self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis]
        return crop_image(self.visualization, self.zoomed_bounding_box)

    def remove_last_mask(self):
        # Pop the most recent mask and drop its bounding box via the shared mask_uid.
        mask_id = len(self.masks) - 1
        mask_uid = self.masks.pop(mask_id)
        self.bounding_boxes.remove(mask_uid)

    def make_labels(self):
        return self.masks.label_map

    def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL):
        if self.partial_mask is not None:
            last_mask = self.partial_mask
            self.partial_mask = None
        else:
            last_mask = self.last_mask
        mask_uid = self.masks.add_mask(last_mask, label=label)
        corresponding_bounding_box = get_mask_bounding_box(last_mask, label)
        corresponding_bounding_box.mask_uid = mask_uid
        self.bounding_boxes.append(corresponding_bounding_box)
        if len(self.masks) >= self.MAX_MASKS:
            self.MAX_MASKS += 10
            self.cmap = get_cmap(self.MAX_MASKS)

    @property
    def is_show_bounding_boxes(self):
        return self.parent.settings.is_show_bounding_boxes()

    def clear_last_masks(self):
        self.last_mask = None
        self.partial_mask = None
        self.visualization = None

    def clear(self):
        self.last_mask = None
        self.visualization = None
        self.masks = MasksAnnotation()
        self.bounding_boxes = BoundingBoxAnnotation()
        self.partial_mask = None
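

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the annotation GUI): build an
    # Annotator from two synthetic binary masks and render the instance mask.
    # No SAM weights, predictor, or Qt parent are needed for this code path.
    demo_masks = [np.zeros((64, 64), dtype=np.uint8) for _ in range(2)]
    demo_masks[0][8:24, 8:24] = 255
    demo_masks[1][32:56, 32:56] = 255
    demo_annotator = Annotator(masks=MasksAnnotation.from_masks(demo_masks, ["object_1", "object_2"]))
    instance_mask = demo_annotator.make_instance_mask()
    print("instance ids:", np.unique(instance_mask))  # expected: [0 1 2]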