Files
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

446 lines
17 KiB
Python

import dataclasses
from typing import Callable
import uuid
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit
from segment_anything import SamPredictor
from segment_anything.build_sam import Sam
from segment_anything_ui.model_builder import (
get_predictor, get_mask_generator, SamPredictor)
try:
    from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so this single
    # clause covers both. The placeholders keep the names importable and
    # usable in type annotations when the optional classes are unavailable.
    class EfficientViTSamPredictor:
        """Placeholder used when the EfficientViT predictor cannot be imported."""

    class EfficientViTSam:
        """Placeholder used when the EfficientViT model cannot be imported."""
from skimage.measure import regionprops
import torch
from segment_anything_ui.utils.shapes import BoundingBox
from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box
def get_cmap(n, name='hsv'):
    """Return a colormap that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color.

    Args:
        n: Number of entries the colormap is resampled to.
        name: Name of a standard matplotlib colormap.

    Returns:
        The colormap object produced by ``plt.get_cmap(name, n)``.
    """
    # ``pyplot.get_cmap`` accepts the same (name, lut) arguments as the
    # ``plt.cm.get_cmap`` helper that newer Matplotlib releases removed, so
    # calling it directly avoids a catch-all ``except`` that would also hide
    # unrelated errors (bad colormap name, KeyboardInterrupt, ...).
    return plt.get_cmap(name, n)
def crop_image(
    image,
    box: BoundingBox | None = None,
    image_shape: tuple[int, int] | None = None
):
    """Cut ``box`` out of ``image`` and resize the result to ``image_shape``.

    Args:
        image: Array indexed as ``image[y, x]`` or ``image[y, x, channel]``.
        box: Region to keep, read through its ``xstart``/``xend``/``ystart``/
            ``yend`` attributes. ``None`` keeps the whole image.
        image_shape: Size handed to ``cv2.resize``. Defaults to the first two
            entries of ``image.shape`` in reversed order.

    Returns:
        The result of ``cv2.resize`` on the selected region.
    """
    if image_shape is None:
        # cv2.resize takes the size in the opposite order of ndarray.shape.
        image_shape = tuple(reversed(image.shape[:2]))

    if box is None:
        region = image
    elif len(image.shape) == 2:
        region = image[box.ystart:box.yend, box.xstart:box.xend]
    else:
        region = image[box.ystart:box.yend, box.xstart:box.xend, :]
    return cv2.resize(region, image_shape)
def insert_image(image, box: BoundingBox | None = None):
    """Paste a resized copy of ``image`` into the ``box`` region of an empty canvas.

    Args:
        image: Array to place; it is cast to ``uint8`` before resizing.
        box: Target region, read through its ``xstart``/``xend``/``ystart``/
            ``yend`` attributes. ``None`` returns ``image`` itself, unchanged.

    Returns:
        ``image`` when ``box`` is ``None``; otherwise a zero-filled array with
        the shape and dtype of ``image`` whose ``box`` region holds the
        resized image.
    """
    if box is None:
        return image

    canvas = np.zeros_like(image)
    # Size of the destination region, in the order cv2.resize expects.
    target_size = (
        int(box.xend) - int(box.xstart),
        int(box.yend) - int(box.ystart),
    )
    patch = cv2.resize(image.astype(np.uint8), target_size)

    rows = slice(box.ystart, box.yend)
    cols = slice(box.xstart, box.xend)
    if len(image.shape) == 2:
        canvas[rows, cols] = patch
    else:
        canvas[rows, cols, :] = patch
    return canvas
@dataclasses.dataclass
class AutomaticMaskGeneratorSettings:
    """Parameters for automatic mask generation.

    The field names double as keyword-argument names: the whole instance is
    expanded with ``dataclasses.asdict`` when the mask generator is built, and
    each field's annotated type is used to convert the text entered in the
    settings form.
    """

    # Integer-valued settings are annotated ``int``, the thresholds ``float``;
    # the annotation is what the form uses to parse user input.
    points_per_side: int = 32
    pred_iou_thresh: float = 0.88
    stability_score_thresh: float = 0.95
    stability_score_offset: float = 1.0
    box_nms_thresh: float = 0.7
    crop_n_layers: int = 0
    crop_nms_thresh: float = 0.7
class LabelValueParam(QWidget):
    """A centred caption stacked above a single-line text field."""

    def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None):
        """Build the caption/field pair.

        Args:
            label_text: Text shown in the caption.
            default_value: Initial text of the edit field.
            value_type_converter: Callable applied to the field's text by
                ``get_value``; the default returns the text unchanged.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.converter = value_type_converter

        column = QVBoxLayout(self)
        column.setSpacing(0)
        column.setContentsMargins(0, 0, 0, 0)
        self.layout = column

        self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter)
        self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter)
        for child in (self.label, self.value):
            column.addWidget(child)

    def get_value(self):
        """Return the field's current text passed through the converter."""
        raw_text = self.value.text()
        return self.converter(raw_text)
class CustomForm(QWidget):
    """Form with one ``LabelValueParam`` row per mask-generator setting."""

    def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None:
        """Create one editable row for every dataclass field of the settings.

        Args:
            parent: Parent widget.
            automatic_mask_generator_settings: Settings whose current values
                pre-fill the form. Each field's annotated type is used as the
                text-to-value converter of its row.
        """
        super().__init__(parent)
        self.layout = QVBoxLayout(self)
        self.layout.setSpacing(0)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.widgets = []
        for field in dataclasses.fields(automatic_mask_generator_settings):
            # Show the value carried by the object that was passed in rather
            # than the dataclass default, so customised settings are not
            # silently reset; fall back to the default if it is missing.
            current_value = getattr(automatic_mask_generator_settings, field.name, field.default)
            widget = LabelValueParam(field.name, str(current_value), field.type)
            self.widgets.append(widget)
            self.layout.addWidget(widget)

    def get_values(self):
        """Return a new ``AutomaticMaskGeneratorSettings`` built from the rows.

        Each row's caption is the field name and its converted text is the
        value, so a conversion error raised by a row's converter propagates
        to the caller.
        """
        values = {widget.label.text(): widget.get_value() for widget in self.widgets}
        return AutomaticMaskGeneratorSettings(**values)
class BoundingBoxAnnotation:
    """List of bounding boxes plus the index of the currently selected one."""

    def __init__(self) -> None:
        self.bounding_boxes: list[BoundingBox] = []
        # Index into ``bounding_boxes``; starts at -1.
        self.bounding_box_id: int = -1

    def append(self, bounding_box: BoundingBox):
        """Add ``bounding_box`` at the end of the list."""
        self.bounding_boxes.append(bounding_box)

    def find_closest_bounding_box(self, point: np.ndarray):
        """Select the box that contains ``point`` and has the smallest distance to it.

        The result is also stored in ``bounding_box_id``.

        Returns:
            ``(box, index)``, or ``(None, -1)`` when no box qualifies.
        """
        best_box = None
        best_index = -1
        best_distance = float('inf')
        for index, candidate in enumerate(self.bounding_boxes):
            distance = candidate.distance_to(point)
            if not (distance < best_distance and candidate.contains(point)):
                continue
            best_box, best_index, best_distance = candidate, index, distance
        self.bounding_box_id = best_index
        return best_box, best_index

    def get_bounding_box(self, bounding_box_id: int):
        """Return the box stored at index ``bounding_box_id``."""
        return self.bounding_boxes[bounding_box_id]

    def get_current_bounding_box(self):
        """Return the box at the currently selected index."""
        return self.bounding_boxes[self.bounding_box_id]

    def set_current_bounding_box(self, bounding_box: BoundingBox):
        """Replace the box at the currently selected index."""
        self.bounding_boxes[self.bounding_box_id] = bounding_box

    def remove(self, mask_uid: str):
        """Remove and return the first box whose ``mask_uid`` matches.

        Returns ``None`` when no box matches. The selected index is
        decremented when it is at or after the removed position.
        """
        for index, candidate in enumerate(self.bounding_boxes):
            if candidate.mask_uid == mask_uid:
                break
        else:
            return None
        removed = self.bounding_boxes.pop(index)
        if self.bounding_box_id >= index:
            self.bounding_box_id -= 1
        return removed

    def remove_by_id(self, bounding_box_id: int):
        """Remove the box at index ``bounding_box_id`` and return its ``mask_uid``."""
        uid = self.bounding_boxes[bounding_box_id].mask_uid
        self.remove(uid)
        return uid

    def __len__(self):
        return len(self.bounding_boxes)
class MasksAnnotation:
DEFAULT_LABEL = "default"
def __init__(self) -> None:
self.masks = []
self.label_map = {}
self.masks_uids: list[str] = []
self.mask_id: int = -1
def add_mask(self, mask, label: str | None = None):
self.masks.append(mask)
self.masks_uids.append(str(uuid.uuid4()))
self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label
return self.masks_uids[-1]
def add_label(self, mask_id: int, label: str):
self.label_map[mask_id] = label
def get_mask(self, mask_id: int):
return self.masks[mask_id]
def get_label(self, mask_id: int):
return self.label_map[mask_id]
def get_current_mask(self):
return self.masks[self.mask_id]
def set_current_mask(self, mask, label: str = None):
self.masks[self.mask_id] = mask
self.label_map[self.mask_id] = self.DEFAULT_LABEL if label is None else label
def __getitem__(self, mask_id: int):
return self.get_mask(mask_id)
def __setitem__(self, mask_id: int, value):
self.masks[mask_id] = value
def __len__(self):
return len(self.masks)
def __iter__(self):
return iter(zip(self.masks, self.label_map.values()))
def __next__(self):
if self.mask_id >= len(self.masks):
raise StopIteration
return self.masks[self.mask_id]
def append(self, mask, label: str | None = None):
return self.add_mask(mask, label)
def pop_by_uuid(self, mask_uid: str):
mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None)
if mask_id is None:
return
return self.pop(mask_id)
def pop(self, mask_id: int = -1):
_ = self.masks.pop(mask_id)
mask_uid = self.masks_uids.pop(mask_id)
self.label_map.pop(mask_id + 1)
new_label_map = {}
for index, value in enumerate(self.label_map.values()):
new_label_map[index + 1] = value
self.label_map = new_label_map
return mask_uid
@classmethod
def from_masks(cls, masks, labels: list[str] | None = None):
annotation = cls()
if labels is None:
labels = [None] * len(masks)
for mask, label in zip(masks, labels):
annotation.append(mask, label)
return annotation
@dataclasses.dataclass()
class Annotator:
    """State and operations of one annotation session.

    Holds the image, the SAM model and predictor, the saved masks with their
    bounding boxes, and the in-progress masks. ``parent`` is the widget used
    to read UI settings and the selected label and to display visualizations.
    """

    sam: Sam | EfficientViTSam | None = None
    embedding: torch.Tensor | None = None
    image: np.ndarray | None = None
    masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation)
    bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation)
    predictor: SamPredictor | EfficientViTSamPredictor | None = None
    visualization: np.ndarray | None = None
    last_mask: np.ndarray | None = None
    partial_mask: np.ndarray | None = None
    merged_mask: np.ndarray | None = None
    parent: QWidget | None = None
    cmap: plt.cm = None
    original_image: np.ndarray | None = None
    zoomed_bounding_box: BoundingBox | None = None

    def __post_init__(self):
        # Number of colors the colormap is built for; grown in save_mask.
        self.MAX_MASKS = 10
        self.cmap = get_cmap(self.MAX_MASKS)

    def set_image(self, image: np.ndarray):
        """Store ``image`` and return ``self`` so calls can be chained."""
        self.image = image
        return self

    def make_embedding(self):
        """Create a predictor for ``sam`` and feed it the (cropped) image.

        Does nothing when no model is set.
        """
        if self.sam is None:
            return
        self.predictor = get_predictor(self.sam)
        self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box))

    def predict_all(self, settings: AutomaticMaskGeneratorSettings):
        """Replace ``masks`` with automatically generated masks for ``image``.

        Every generated mask gets the label currently selected in the
        parent's label picker; the colormap is rebuilt for the new count.
        """
        generator = get_mask_generator(
            sam=self.sam,
            **dataclasses.asdict(settings)
        )
        masks = generator.generate(self.image)
        masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks]
        label = self.parent.annotation_layout.label_picker.currentItem().text()
        self.masks = MasksAnnotation.from_masks(masks, [label, ] * len(masks))
        self.cmap = get_cmap(len(self.masks))

    def make_prediction(self, annotation: dict):
        """Run the predictor on the prompts in ``annotation``.

        Reads the ``"points"``, ``"labels"`` and ``"bounding_boxes"`` keys and
        stores the first returned mask, placed back into the zoomed region
        and scaled by 255, in ``last_mask``.
        """
        masks, scores, logits = self.predictor.predict(
            point_coords=annotation["points"],
            point_labels=annotation["labels"],
            box=annotation["bounding_boxes"],
            multimask_output=False
        )
        mask = masks[0]
        self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255

    def pick_partial_mask(self):
        """Fold ``last_mask`` into ``partial_mask`` and clear ``last_mask``."""
        if self.partial_mask is None:
            self.partial_mask = self.last_mask.copy()
        else:
            self.partial_mask = np.maximum(self.last_mask, self.partial_mask)
        self.last_mask = None

    def move_current_mask_to_background(self):
        """Halve the values of the current mask."""
        self.masks.set_current_mask(self.masks.get_current_mask() * 0.5)

    def merge_masks(self):
        """OR ``last_mask`` with ``merged_mask`` and store it as the current mask.

        The label is taken from the parent's label picker; ``merged_mask`` is
        cleared afterwards.
        """
        new_mask = np.bitwise_or(self.last_mask, self.merged_mask)
        self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text())
        self.merged_mask = None

    def visualize_last_mask(self, label: str | None = None):
        """Overlay the in-progress masks on the image and show the result.

        ``last_mask`` is written to channel 1, ``partial_mask`` to channel 0
        and ``merged_mask`` to channel 2 of the overlay. When ``label`` is
        given it is drawn at the centroid of the first region of
        ``last_mask``; when bounding boxes are enabled the mask's box and the
        label are drawn as well.
        """
        last_mask = np.zeros_like(self.image)
        last_mask[:, :, 1] = self.last_mask
        if self.partial_mask is not None:
            last_mask[:, :, 0] = self.partial_mask
        if self.merged_mask is not None:
            last_mask[:, :, 2] = self.merged_mask
        if label is not None:
            regions = regionprops(self.last_mask)
            # An all-zero mask has no region; skip the caption instead of
            # failing on regions[0].
            if regions:
                props = regions[0]
                cv2.putText(
                    last_mask,
                    label,
                    (int(props.centroid[1]), int(props.centroid[0])),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    [255, 255, 255],
                    2
                )
        if self.is_show_bounding_boxes:
            last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label)
            # Box coordinates are multiplied by the image size to obtain pixels.
            cv2.rectangle(
                last_mask,
                (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
                (int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])),
                (0, 255, 0),
                2
            )
            cv2.putText(
                last_mask,
                label,
                (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                [255, 255, 255],
                2
            )
        visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(),
                                        0.8, last_mask, 0.5, 0)
        self.parent.update(crop_image(visualization, self.zoomed_bounding_box))

    def _visualize_mask(self) -> tuple:
        """Render the saved masks as a color image plus a border mask.

        Returns:
            ``(visualization, border)`` where ``border`` is 0 on mask outlines
            and 1 elsewhere.
        """
        mask_argmax = self.make_instance_mask()
        visualization = np.zeros_like(self.image)
        border = np.zeros(self.image.shape[:2], dtype=np.uint8)
        for i in range(1, np.amax(mask_argmax) + 1):
            color = self.cmap(i)
            single_mask = np.zeros_like(mask_argmax)
            single_mask[mask_argmax == i] = 1
            visualization[mask_argmax == i, :] = np.array(color[:3]) * 255
            # Mask minus its erosion leaves only the outline pixels.
            border += single_mask - cv2.erode(
                single_mask, np.ones((3, 3), np.uint8), iterations=1)
            label = self.masks.get_label(i)
            single_mask_center = np.mean(np.where(single_mask == 1), axis=1)
            # A mask with no visible pixels has a NaN center: nothing to label.
            if np.isnan(single_mask_center).any():
                continue
            if self.parent.settings.is_show_text():
                cv2.putText(
                    visualization,
                    label,
                    (int(single_mask_center[1]), int(single_mask_center[0])),
                    cv2.FONT_HERSHEY_PLAIN,
                    0.5,
                    [255, 255, 255],
                    1
                )
        if self.is_show_bounding_boxes:
            bounding_boxes = self.get_bounding_boxes()
            for idx, bounding_box in enumerate(bounding_boxes):
                cv2.rectangle(
                    visualization,
                    (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
                    (int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])),
                    # The selected box uses a different color from the others.
                    (0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0),
                    2
                )
                cv2.putText(
                    visualization,
                    bounding_box.label,
                    (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    [255, 255, 255],
                    2
                )
        border = (border == 0).astype(np.uint8)
        return visualization, border

    def has_annotations(self):
        """Return whether at least one mask is saved."""
        return len(self.masks) > 0

    def make_instance_mask(self):
        """Return a per-pixel instance map of the saved masks.

        A constant background layer of ones is stacked in front of the masks
        and the argmax over the stack is taken, so 0 means background and
        ``k`` means ``masks[k - 1]``. The result is cast to ``uint8``.
        """
        background = np.zeros_like(self.masks[0]) + 1
        mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8)
        return mask_argmax

    def get_bounding_boxes(self):
        """Compute bounding boxes for all saved masks and their labels."""
        return get_bounding_boxes(self.masks.masks, self.masks.label_map.values())

    def merge_image_visualization(self):
        """Return the image blended with the saved masks, cropped to the zoom box.

        With no saved masks a cropped copy of the plain image is returned.
        Otherwise the blend is stored in ``visualization`` with outline
        pixels zeroed.
        """
        image = self.image.copy()
        if not len(self.masks):
            return crop_image(image, self.zoomed_bounding_box)
        visualization, border = self._visualize_mask()
        self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis]
        return crop_image(self.visualization, self.zoomed_bounding_box)

    def remove_last_mask(self):
        """Remove the most recently added mask and its bounding box.

        Does nothing when no mask is saved.
        """
        mask_count = len(self.masks)
        if not mask_count:
            return
        # ``masks.pop`` takes a 0-based list index and returns the uid of the
        # removed mask; bounding boxes are matched by that uid.
        mask_uid = self.masks.pop(mask_count - 1)
        self.bounding_boxes.remove(mask_uid)

    def make_labels(self):
        """Return the label map of the saved masks."""
        return self.masks.label_map

    def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL):
        """Save the in-progress mask under ``label`` with its bounding box.

        ``partial_mask`` takes precedence over ``last_mask`` and is cleared
        once used. The colormap is enlarged when the mask count reaches
        ``MAX_MASKS``.
        """
        if self.partial_mask is not None:
            last_mask = self.partial_mask
            self.partial_mask = None
        else:
            last_mask = self.last_mask
        mask_uid = self.masks.add_mask(last_mask, label=label)
        corresponding_bounding_box = get_mask_bounding_box(last_mask, label)
        # Link the box to its mask so both can be removed together.
        corresponding_bounding_box.mask_uid = mask_uid
        self.bounding_boxes.append(corresponding_bounding_box)
        if len(self.masks) >= self.MAX_MASKS:
            self.MAX_MASKS += 10
            self.cmap = get_cmap(self.MAX_MASKS)

    @property
    def is_show_bounding_boxes(self):
        """Whether the parent's settings ask for bounding boxes to be drawn."""
        return self.parent.settings.is_show_bounding_boxes()

    def clear_last_masks(self):
        """Discard the in-progress masks and the cached visualization."""
        self.last_mask = None
        self.partial_mask = None
        self.visualization = None

    def clear(self):
        """Reset saved masks, bounding boxes, in-progress masks and visualization."""
        self.last_mask = None
        self.visualization = None
        self.masks = MasksAnnotation()
        self.bounding_boxes = BoundingBoxAnnotation()
        self.partial_mask = None