initial_tune
This commit is contained in:
445
segment_anything_ui/annotator.py
Normal file
445
segment_anything_ui/annotator.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import dataclasses
|
||||
from typing import Callable
|
||||
import uuid
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PySide6.QtCore import Qt
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit
|
||||
from segment_anything import SamPredictor
|
||||
from segment_anything.build_sam import Sam
|
||||
from segment_anything_ui.model_builder import (
|
||||
get_predictor, get_mask_generator, SamPredictor)
|
||||
try:
|
||||
from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
class EfficientViTSamPredictor:
|
||||
pass
|
||||
|
||||
class EfficientViTSam:
|
||||
pass
|
||||
|
||||
from skimage.measure import regionprops
|
||||
import torch
|
||||
|
||||
from segment_anything_ui.utils.shapes import BoundingBox
|
||||
from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box
|
||||
|
||||
def get_cmap(n, name='hsv'):
|
||||
'''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
|
||||
RGB color; the keyword argument name must be a standard mpl colormap name.'''
|
||||
try:
|
||||
return plt.cm.get_cmap(name, n)
|
||||
except:
|
||||
return plt.get_cmap(name, n)
|
||||
|
||||
def crop_image(
|
||||
image,
|
||||
box: BoundingBox | None = None,
|
||||
image_shape: tuple[int, int] | None = None
|
||||
):
|
||||
if image_shape is None:
|
||||
image_shape = image.shape[:2][::-1]
|
||||
if box is None:
|
||||
return cv2.resize(image, image_shape)
|
||||
|
||||
if len(image.shape) == 2:
|
||||
return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend], image_shape)
|
||||
return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend, :], image_shape)
|
||||
|
||||
|
||||
def insert_image(image, box: BoundingBox | None = None):
|
||||
new_image = np.zeros_like(image)
|
||||
if box is None:
|
||||
new_image = image
|
||||
else:
|
||||
if len(image.shape) == 2:
|
||||
new_image[box.ystart:box.yend, box.xstart:box.xend] = cv2.resize(
|
||||
image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
|
||||
else:
|
||||
new_image[box.ystart:box.yend, box.xstart:box.xend, :] = cv2.resize(
|
||||
image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
|
||||
return new_image
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class AutomaticMaskGeneratorSettings:
|
||||
points_per_side: int = 32
|
||||
pred_iou_thresh: float = 0.88
|
||||
stability_score_thresh: float = 0.95
|
||||
stability_score_offset: float = 1.0
|
||||
box_nms_thresh: float = 0.7
|
||||
crop_n_layers: int = 0
|
||||
crop_nms_thresh: float = 0.7
|
||||
|
||||
|
||||
class LabelValueParam(QWidget):
|
||||
def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None):
|
||||
super().__init__(parent)
|
||||
self.layout = QVBoxLayout(self)
|
||||
self.layout.setSpacing(0)
|
||||
self.layout.setContentsMargins(0, 0, 0, 0)
|
||||
self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter)
|
||||
self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter)
|
||||
self.layout.addWidget(self.label)
|
||||
self.layout.addWidget(self.value)
|
||||
self.converter = value_type_converter
|
||||
|
||||
def get_value(self):
|
||||
return self.converter(self.value.text())
|
||||
|
||||
|
||||
class CustomForm(QWidget):
|
||||
|
||||
def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None:
|
||||
super().__init__(parent)
|
||||
self.layout = QVBoxLayout(self)
|
||||
self.layout.setSpacing(0)
|
||||
self.layout.setContentsMargins(0, 0, 0, 0)
|
||||
self.widgets = []
|
||||
|
||||
for field in dataclasses.fields(automatic_mask_generator_settings):
|
||||
widget = LabelValueParam(field.name, str(field.default), field.type)
|
||||
self.widgets.append(widget)
|
||||
self.layout.addWidget(widget)
|
||||
|
||||
def get_values(self):
|
||||
return AutomaticMaskGeneratorSettings(**{widget.label.text(): widget.get_value() for widget in self.widgets})
|
||||
|
||||
|
||||
class BoundingBoxAnnotation:
|
||||
def __init__(self) -> None:
|
||||
self.bounding_boxes: list[BoundingBox] = []
|
||||
self.bounding_box_id: int = -1
|
||||
|
||||
def append(self, bounding_box: BoundingBox):
|
||||
self.bounding_boxes.append(bounding_box)
|
||||
|
||||
def find_closest_bounding_box(self, point: np.ndarray):
|
||||
closest_bounding_box = None
|
||||
closest_bounding_box_id = -1
|
||||
min_distance = float('inf')
|
||||
for idx, bounding_box in enumerate(self.bounding_boxes):
|
||||
distance = bounding_box.distance_to(point)
|
||||
if distance < min_distance and bounding_box.contains(point):
|
||||
min_distance = distance
|
||||
closest_bounding_box = bounding_box
|
||||
closest_bounding_box_id = idx
|
||||
self.bounding_box_id = closest_bounding_box_id
|
||||
return closest_bounding_box, closest_bounding_box_id
|
||||
|
||||
def get_bounding_box(self, bounding_box_id: int):
|
||||
return self.bounding_boxes[bounding_box_id]
|
||||
|
||||
def get_current_bounding_box(self):
|
||||
return self.bounding_boxes[self.bounding_box_id]
|
||||
|
||||
def set_current_bounding_box(self, bounding_box: BoundingBox):
|
||||
self.bounding_boxes[self.bounding_box_id] = bounding_box
|
||||
|
||||
def remove(self, mask_uid: str):
|
||||
bounding_box_id = next((idx for idx, bounding_box in enumerate(self.bounding_boxes) if bounding_box.mask_uid == mask_uid), None)
|
||||
if bounding_box_id is None:
|
||||
return
|
||||
bounding_box = self.bounding_boxes.pop(bounding_box_id)
|
||||
if self.bounding_box_id >= bounding_box_id:
|
||||
self.bounding_box_id -= 1
|
||||
return bounding_box
|
||||
|
||||
def remove_by_id(self, bounding_box_id: int):
|
||||
mask_uid = self.bounding_boxes[bounding_box_id].mask_uid
|
||||
self.remove(mask_uid)
|
||||
return mask_uid
|
||||
|
||||
def __len__(self):
|
||||
return len(self.bounding_boxes)
|
||||
|
||||
|
||||
class MasksAnnotation:
|
||||
DEFAULT_LABEL = "default"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.masks = []
|
||||
self.label_map = {}
|
||||
self.masks_uids: list[str] = []
|
||||
self.mask_id: int = -1
|
||||
|
||||
def add_mask(self, mask, label: str | None = None):
|
||||
self.masks.append(mask)
|
||||
self.masks_uids.append(str(uuid.uuid4()))
|
||||
self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label
|
||||
return self.masks_uids[-1]
|
||||
|
||||
def add_label(self, mask_id: int, label: str):
|
||||
self.label_map[mask_id] = label
|
||||
|
||||
def get_mask(self, mask_id: int):
|
||||
return self.masks[mask_id]
|
||||
|
||||
def get_label(self, mask_id: int):
|
||||
return self.label_map[mask_id]
|
||||
|
||||
def get_current_mask(self):
|
||||
return self.masks[self.mask_id]
|
||||
|
||||
def set_current_mask(self, mask, label: str = None):
|
||||
self.masks[self.mask_id] = mask
|
||||
self.label_map[self.mask_id] = self.DEFAULT_LABEL if label is None else label
|
||||
|
||||
def __getitem__(self, mask_id: int):
|
||||
return self.get_mask(mask_id)
|
||||
|
||||
def __setitem__(self, mask_id: int, value):
|
||||
self.masks[mask_id] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.masks)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(zip(self.masks, self.label_map.values()))
|
||||
|
||||
def __next__(self):
|
||||
if self.mask_id >= len(self.masks):
|
||||
raise StopIteration
|
||||
return self.masks[self.mask_id]
|
||||
|
||||
def append(self, mask, label: str | None = None):
|
||||
return self.add_mask(mask, label)
|
||||
|
||||
def pop_by_uuid(self, mask_uid: str):
|
||||
mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None)
|
||||
if mask_id is None:
|
||||
return
|
||||
return self.pop(mask_id)
|
||||
|
||||
def pop(self, mask_id: int = -1):
|
||||
_ = self.masks.pop(mask_id)
|
||||
mask_uid = self.masks_uids.pop(mask_id)
|
||||
self.label_map.pop(mask_id + 1)
|
||||
new_label_map = {}
|
||||
for index, value in enumerate(self.label_map.values()):
|
||||
new_label_map[index + 1] = value
|
||||
self.label_map = new_label_map
|
||||
return mask_uid
|
||||
|
||||
@classmethod
|
||||
def from_masks(cls, masks, labels: list[str] | None = None):
|
||||
annotation = cls()
|
||||
if labels is None:
|
||||
labels = [None] * len(masks)
|
||||
for mask, label in zip(masks, labels):
|
||||
annotation.append(mask, label)
|
||||
return annotation
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class Annotator:
|
||||
sam: Sam | EfficientViTSam | None = None
|
||||
embedding: torch.Tensor | None = None
|
||||
image: np.ndarray | None = None
|
||||
masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation)
|
||||
bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation)
|
||||
predictor: SamPredictor | EfficientViTSamPredictor | None = None
|
||||
visualization: np.ndarray | None = None
|
||||
last_mask: np.ndarray | None = None
|
||||
partial_mask: np.ndarray | None = None
|
||||
merged_mask: np.ndarray | None = None
|
||||
parent: QWidget | None = None
|
||||
cmap: plt.cm = None
|
||||
original_image: np.ndarray | None = None
|
||||
zoomed_bounding_box: BoundingBox | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.MAX_MASKS = 10
|
||||
self.cmap = get_cmap(self.MAX_MASKS)
|
||||
|
||||
def set_image(self, image: np.ndarray):
|
||||
self.image = image
|
||||
return self
|
||||
|
||||
def make_embedding(self):
|
||||
if self.sam is None:
|
||||
return
|
||||
self.predictor = get_predictor(self.sam)
|
||||
self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box))
|
||||
|
||||
def predict_all(self, settings: AutomaticMaskGeneratorSettings):
|
||||
generator = get_mask_generator(
|
||||
sam=self.sam,
|
||||
**dataclasses.asdict(settings)
|
||||
)
|
||||
masks = generator.generate(self.image)
|
||||
masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks]
|
||||
label = self.parent.annotation_layout.label_picker.currentItem().text()
|
||||
self.masks = MasksAnnotation.from_masks(masks, [label, ] * len(masks))
|
||||
self.cmap = get_cmap(len(self.masks))
|
||||
|
||||
def make_prediction(self, annotation: dict):
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=annotation["points"],
|
||||
point_labels=annotation["labels"],
|
||||
box=annotation["bounding_boxes"],
|
||||
multimask_output=False
|
||||
)
|
||||
mask = masks[0]
|
||||
self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255
|
||||
|
||||
def pick_partial_mask(self):
|
||||
if self.partial_mask is None:
|
||||
self.partial_mask = self.last_mask.copy()
|
||||
else:
|
||||
self.partial_mask = np.maximum(self.last_mask, self.partial_mask)
|
||||
self.last_mask = None
|
||||
|
||||
def move_current_mask_to_background(self):
|
||||
self.masks.set_current_mask(self.masks.get_current_mask() * 0.5)
|
||||
|
||||
def merge_masks(self):
|
||||
new_mask = np.bitwise_or(self.last_mask, self.merged_mask)
|
||||
self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text())
|
||||
self.merged_mask = None
|
||||
|
||||
def visualize_last_mask(self, label: str | None = None):
|
||||
last_mask = np.zeros_like(self.image)
|
||||
last_mask[:, :, 1] = self.last_mask
|
||||
if self.partial_mask is not None:
|
||||
last_mask[:, :, 0] = self.partial_mask
|
||||
if self.merged_mask is not None:
|
||||
last_mask[:, :, 2] = self.merged_mask
|
||||
if label is not None:
|
||||
props = regionprops(self.last_mask)[0]
|
||||
cv2.putText(
|
||||
last_mask,
|
||||
label,
|
||||
(int(props.centroid[1]), int(props.centroid[0])),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1.0,
|
||||
[255, 255, 255],
|
||||
2
|
||||
)
|
||||
if self.is_show_bounding_boxes:
|
||||
last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label)
|
||||
cv2.rectangle(
|
||||
last_mask,
|
||||
(int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
|
||||
(int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])),
|
||||
(0, 255, 0),
|
||||
2
|
||||
)
|
||||
cv2.putText(
|
||||
last_mask,
|
||||
label,
|
||||
(int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1.0,
|
||||
[255, 255, 255],
|
||||
2
|
||||
)
|
||||
visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(),
|
||||
0.8, last_mask, 0.5, 0)
|
||||
self.parent.update(crop_image(visualization, self.zoomed_bounding_box))
|
||||
|
||||
def _visualize_mask(self) -> tuple:
|
||||
mask_argmax = self.make_instance_mask()
|
||||
visualization = np.zeros_like(self.image)
|
||||
border = np.zeros(self.image.shape[:2], dtype=np.uint8)
|
||||
for i in range(1, np.amax(mask_argmax) + 1):
|
||||
color = self.cmap(i)
|
||||
single_mask = np.zeros_like(mask_argmax)
|
||||
single_mask[mask_argmax == i] = 1
|
||||
visualization[mask_argmax == i, :] = np.array(color[:3]) * 255
|
||||
border += single_mask - cv2.erode(
|
||||
single_mask, np.ones((3, 3), np.uint8), iterations=1)
|
||||
label = self.masks.get_label(i)
|
||||
single_mask_center = np.mean(np.where(single_mask == 1), axis=1)
|
||||
if np.isnan(single_mask_center).any():
|
||||
continue
|
||||
if self.parent.settings.is_show_text():
|
||||
cv2.putText(
|
||||
visualization,
|
||||
label,
|
||||
(int(single_mask_center[1]), int(single_mask_center[0])),
|
||||
cv2.FONT_HERSHEY_PLAIN,
|
||||
0.5,
|
||||
[255, 255, 255],
|
||||
1
|
||||
)
|
||||
if self.is_show_bounding_boxes:
|
||||
bounding_boxes = self.get_bounding_boxes()
|
||||
for idx, bounding_box in enumerate(bounding_boxes):
|
||||
cv2.rectangle(
|
||||
visualization,
|
||||
(int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
|
||||
(int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])),
|
||||
(0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0),
|
||||
2
|
||||
)
|
||||
cv2.putText(
|
||||
visualization,
|
||||
bounding_box.label,
|
||||
(int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1.0,
|
||||
[255, 255, 255],
|
||||
2
|
||||
)
|
||||
border = (border == 0).astype(np.uint8)
|
||||
return visualization, border
|
||||
|
||||
def has_annotations(self):
|
||||
return len(self.masks) > 0
|
||||
|
||||
def make_instance_mask(self):
|
||||
background = np.zeros_like(self.masks[0]) + 1
|
||||
mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8)
|
||||
return mask_argmax
|
||||
|
||||
def get_bounding_boxes(self):
|
||||
return get_bounding_boxes(self.masks.masks, self.masks.label_map.values())
|
||||
|
||||
def merge_image_visualization(self):
|
||||
image = self.image.copy()
|
||||
if not len(self.masks):
|
||||
return crop_image(image, self.zoomed_bounding_box)
|
||||
visualization, border = self._visualize_mask()
|
||||
self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis]
|
||||
return crop_image(self.visualization, self.zoomed_bounding_box)
|
||||
|
||||
def remove_last_mask(self):
|
||||
mask_id = len(self.masks)
|
||||
self.masks.pop(mask_id)
|
||||
self.bounding_boxes.remove(mask_id)
|
||||
|
||||
def make_labels(self):
|
||||
return self.masks.label_map
|
||||
|
||||
def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL):
|
||||
if self.partial_mask is not None:
|
||||
last_mask = self.partial_mask
|
||||
self.partial_mask = None
|
||||
else:
|
||||
last_mask = self.last_mask
|
||||
mask_uid = self.masks.add_mask(last_mask, label=label)
|
||||
corresponding_bounding_box = get_mask_bounding_box(last_mask, label)
|
||||
corresponding_bounding_box.mask_uid = mask_uid
|
||||
self.bounding_boxes.append(corresponding_bounding_box)
|
||||
if len(self.masks) >= self.MAX_MASKS:
|
||||
self.MAX_MASKS += 10
|
||||
self.cmap = get_cmap(self.MAX_MASKS)
|
||||
|
||||
@property
|
||||
def is_show_bounding_boxes(self):
|
||||
return self.parent.settings.is_show_bounding_boxes()
|
||||
|
||||
def clear_last_masks(self):
|
||||
self.last_mask = None
|
||||
self.partial_mask = None
|
||||
self.visualization = None
|
||||
|
||||
def clear(self):
|
||||
self.last_mask = None
|
||||
self.visualization = None
|
||||
self.masks = MasksAnnotation()
|
||||
self.bounding_boxes = BoundingBoxAnnotation()
|
||||
self.partial_mask = None
|
||||
Reference in New Issue
Block a user