219 lines
11 KiB
Python
219 lines
11 KiB
Python
import enum
|
|
import json
|
|
import os
|
|
import numpy as np
|
|
from PySide6.QtCore import Qt
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox
|
|
|
|
from segment_anything_ui.draw_label import PaintType
|
|
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation
|
|
|
|
|
|
@enum.unique
class MergeState(enum.Enum):
    """Phase of the two-press "Merge Masks" workflow."""

    # Next press of the merge button remembers the current mask as merge base.
    PICKING = enum.auto()
    # A base mask has been remembered; the next press performs the merge.
    MERGING = enum.auto()
|
|
|
|
|
|
class AnnotationLayout(QWidget):
    """Vertical panel with the annotation tool buttons, label picker and settings.

    Every handler delegates to the parent widget, which is expected to expose
    ``annotator``, ``image_label``, ``info_label`` and an ``update(image)``
    method (all of them are accessed through ``self.parent()`` below).
    """

    def __init__(self, parent, config) -> None:
        """Build all widgets, register shortcuts and connect the click handlers.

        Args:
            parent: Owning widget; see the class docstring for what it must expose.
            config: Configuration object providing ``key_mapping``, ``label_file``
                and ``window_size``.
        """
        super().__init__(parent)
        self.config = config
        self.zoom_flag = False
        self.merge_state = MergeState.PICKING
        # NOTE: this attribute shadows QWidget.layout(); the name is kept because
        # code outside this class may already read ``.layout`` as an attribute.
        self.layout = QVBoxLayout(self)
        labels = self._load_labels(config)
        self.layout.setAlignment(Qt.AlignTop)

        keys = config.key_mapping
        self.add_point = QPushButton(f"Add Point [ {keys.ADD_POINT.name} ]")
        self.add_box = QPushButton(f"Add Box [ {keys.ADD_BOX.name} ]")
        self.annotate_all = QPushButton(f"Annotate All [ {keys.ANNOTATE_ALL.name} ]")
        self.manual_polygon = QPushButton(f"Manual Polygon [ {keys.MANUAL_POLYGON.name} ]")
        self.cancel_annotation = QPushButton(f"Cancel Annotation [ {keys.CANCEL_ANNOTATION.name} ]")
        self.save_annotation = QPushButton(f"Save Annotation [ {keys.SAVE_ANNOTATION.name} ]")
        self.pick_mask = QPushButton(f"Pick Mask [ {keys.PICK_MASK.name} ]")
        self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {keys.PICK_BOUNDING_BOX.name} ]")
        self.merge_masks = QPushButton(f"Merge Masks [ {keys.MERGE_MASK.name} ]")
        self.delete_mask = QPushButton(f"Delete Mask [ {keys.DELETE_MASK.name} ]")
        self.partial_annotation = QPushButton(f"Partial Annotation [ {keys.PARTIAL_ANNOTATION.name} ]")
        self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {keys.ZOOM_RECTANGLE.name} ]")

        self.label_picker = QListWidget()
        # ``labels`` is either the fallback list or the JSON object loaded from
        # the label file; list() turns both into a plain list of label names.
        self.label_picker.addItems(list(labels))
        self.label_picker.setCurrentRow(0)

        self.move_current_mask_background = QPushButton("Move Current Mask to Front")
        self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
        self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than")
        self.remove_hidden_masks_line = QLineEdit("0.5")

        # NOTE(review): merge_masks and manual_polygon display a key name but get
        # no shortcut here - presumably handled elsewhere; confirm before adding.
        self.save_annotation.setShortcut(keys.SAVE_ANNOTATION.key)
        self.add_point.setShortcut(keys.ADD_POINT.key)
        self.add_box.setShortcut(keys.ADD_BOX.key)
        self.annotate_all.setShortcut(keys.ANNOTATE_ALL.key)
        self.cancel_annotation.setShortcut(keys.CANCEL_ANNOTATION.key)
        self.pick_mask.setShortcut(keys.PICK_MASK.key)
        self.pick_bounding_box.setShortcut(keys.PICK_BOUNDING_BOX.key)
        self.partial_annotation.setShortcut(keys.PARTIAL_ANNOTATION.key)
        self.delete_mask.setShortcut(keys.DELETE_MASK.key)
        self.zoom_rectangle.setShortcut(keys.ZOOM_RECTANGLE.key)

        self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())

        # The order of this list is the top-to-bottom order in the panel.
        for w in [
            self.add_point,
            self.add_box,
            self.annotate_all,
            self.pick_mask,
            self.pick_bounding_box,
            self.merge_masks,
            self.move_current_mask_background,
            self.cancel_annotation,
            self.delete_mask,
            self.partial_annotation,
            self.save_annotation,
            self.manual_polygon,
            self.label_picker,
            self.zoom_rectangle,
            self.annotation_settings,
            self.remove_hidden_masks,
            self.remove_hidden_masks_label,
            self.remove_hidden_masks_line,
        ]:
            self.layout.addWidget(w)

        self.add_point.clicked.connect(self.on_add_point)
        self.add_box.clicked.connect(self.on_add_box)
        self.annotate_all.clicked.connect(self.on_annotate_all)
        self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
        self.save_annotation.clicked.connect(self.on_save_annotation)
        self.pick_mask.clicked.connect(self.on_pick_mask)
        self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
        self.manual_polygon.clicked.connect(self.on_manual_polygon)
        self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
        self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
        self.merge_masks.clicked.connect(self.on_merge_masks)
        self.partial_annotation.clicked.connect(self.on_partial_annotation)
        self.delete_mask.clicked.connect(self.on_delete_mask)
        self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)

    def on_delete_mask(self):
        """Delete the picked mask or bounding box, depending on the active picker.

        Shows a warning dialog when neither the mask picker nor the box picker
        is the current paint type.
        """
        window = self.parent()
        annotator = window.annotator
        paint_type = window.image_label.paint_type

        if paint_type == PaintType.MASK_PICKER:
            window.info_label.setText("Deleting mask!")
            mask_uid = annotator.masks.pop(annotator.masks.mask_id)
            annotator.bounding_boxes.remove(mask_uid)
            annotator.masks.mask_id = -1
            annotator.last_mask = None
            window.update(annotator.merge_image_visualization())
        elif paint_type == PaintType.BOX_PICKER:
            window.info_label.setText("Deleting bounding box!")
            mask_uid = annotator.bounding_boxes.remove_by_id(
                annotator.bounding_boxes.bounding_box_id)
            # Only remove a mask when remove_by_id reported an associated uid.
            if mask_uid is not None:
                annotator.masks.pop_by_uuid(mask_uid)
            annotator.bounding_boxes.bounding_box_id = -1
            annotator.last_mask = None
            annotator.masks.mask_id = -1
            window.update(annotator.merge_image_visualization())
        else:
            QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")

    def on_partial_annotation(self):
        """Ask the annotator to pick a partial mask and clear the image label."""
        window = self.parent()
        window.info_label.setText("Partial annotation!")
        window.annotator.pick_partial_mask()
        window.image_label.clear()

    @staticmethod
    def _load_labels(config):
        """Load label definitions from ``config.label_file``.

        Returns ``["default"]`` when the file does not exist; otherwise returns
        the parsed JSON content unchanged and, as a side effect, sets
        ``MasksAnnotation.DEFAULT_LABEL`` to the first label name (or
        ``"default"`` when the file defines no labels).
        """
        if not os.path.exists(config.label_file):
            return ["default"]
        # Read as UTF-8 explicitly so the result does not depend on the platform
        # default encoding.
        with open(config.label_file, "r", encoding="utf-8") as f:
            labels = json.load(f)
        MasksAnnotation.DEFAULT_LABEL = next(iter(labels), "default")
        return labels

    def on_merge_masks(self):
        """Drive the two-press merge workflow.

        First press (``PICKING``): remember a copy of the current mask as the
        merge base and switch to ``MERGING``. Second press (``MERGING``): merge
        and return to ``PICKING``. If no mask is available on the first press a
        warning is shown and the state is left unchanged.
        """
        window = self.parent()
        annotator = window.annotator
        window.image_label.change_paint_type(PaintType.MASK_PICKER)

        if self.merge_state == MergeState.PICKING:
            base_mask = annotator.last_mask
            if base_mask is None:
                # Previously this raised AttributeError *after* the state had
                # already flipped to MERGING, leaving the workflow stuck.
                QMessageBox.warning(self, "Error", "Please pick a mask to merge first!")
                return
            window.info_label.setText("Pick a mask to merge with!")
            annotator.merged_mask = base_mask.copy()
            self.merge_state = MergeState.MERGING
        elif self.merge_state == MergeState.MERGING:
            window.info_label.setText("Merging masks!")
            annotator.merge_masks()
            self.merge_state = MergeState.PICKING

    def on_move_current_mask_background_fn(self):
        """Call ``move_current_mask_to_background`` on the annotator and redraw."""
        # NOTE(review): the button caption says "to Front" while the message and
        # the annotator method say "background" - confirm which is intended.
        window = self.parent()
        window.info_label.setText("Moving current mask to background!")
        window.annotator.move_current_mask_to_background()
        window.update(window.annotator.merge_image_visualization())

    def on_remove_hidden_masks(self):
        """Drop masks whose visible fraction is not above the entered threshold.

        The visible fraction of mask ``idx`` is the number of pixels equal to
        ``idx + 1`` in the annotator's instance mask divided by the number of
        positive pixels in the mask itself. Empty masks are dropped. A
        non-numeric threshold shows a warning and changes nothing.
        """
        # NOTE(review): the on-screen label talks about "hidden area less than",
        # but the comparison below is on the *visible* ratio - confirm wording.
        window = self.parent()
        try:
            limit_ratio = float(self.remove_hidden_masks_line.text())
        except ValueError:
            QMessageBox.warning(self, "Error", "Please enter a numeric ratio for hidden masks!")
            return

        window.info_label.setText("Removing hidden masks!")
        annotations = window.annotator.masks
        argmax_mask = window.annotator.make_instance_mask()

        new_masks = []
        new_labels = []
        for idx, (mask, label) in enumerate(annotations):
            num_pixels = np.sum(mask > 0)
            if num_pixels == 0:
                # Empty mask: avoid the 0/0 division (which used to yield NaN
                # and was dropped by the comparison anyway).
                continue
            num_visible = np.sum(argmax_mask == (idx + 1))
            if num_visible / num_pixels > limit_ratio:
                new_masks.append(mask)
                new_labels.append(label)

        print("Removed ", len(annotations) - len(new_masks), " masks.")
        window.annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
        window.update(window.annotator.merge_image_visualization())

    def on_pick_mask(self):
        """Switch the image label to mask-picking mode."""
        window = self.parent()
        window.info_label.setText("Pick a mask to do modifications!")
        window.image_label.change_paint_type(PaintType.MASK_PICKER)

    def on_pick_bounding_box(self):
        """Switch the image label to bounding-box-picking mode."""
        window = self.parent()
        window.info_label.setText("Pick a bounding box to do modifications!")
        window.image_label.change_paint_type(PaintType.BOX_PICKER)

    def on_manual_polygon(self):
        """Switch the image label to manual polygon drawing."""
        # Sets emphasis on the button
        self.manual_polygon.setProperty("active", True)
        self.parent().image_label.change_paint_type(PaintType.POLYGON)

    def on_add_point(self):
        """Switch the image label to point annotation."""
        window = self.parent()
        window.info_label.setText("Adding point annotation!")
        window.image_label.change_paint_type(PaintType.POINT)

    def on_add_box(self):
        """Switch the image label to box annotation."""
        window = self.parent()
        window.info_label.setText("Adding box annotation!")
        window.image_label.change_paint_type(PaintType.BOX)

    def on_zoom_rectangle(self):
        """Toggle zoom mode.

        When zoom is active: clear the zoomed bounding box, recompute the
        embedding, redraw and go back to point mode. Otherwise: switch the
        image label to the zoom picker.
        """
        window = self.parent()
        if self.zoom_flag:
            window.info_label.setText("Zooming rectangle OFF!")
            window.image_label.change_paint_type(PaintType.POINT)
            window.annotator.zoomed_bounding_box = None
            window.annotator.make_embedding()
            window.update(window.annotator.merge_image_visualization())
            self.zoom_flag = False
        else:
            window.info_label.setText("Pick Mask to zoom!")
            self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
            window.image_label.change_paint_type(PaintType.ZOOM_PICKER)
            self.zoom_flag = True

    def on_annotate_all(self):
        """Run automatic annotation with the values from the settings form."""
        window = self.parent()
        window.info_label.setText("Annotating all. This may take some time.")
        window.annotator.predict_all(self.annotation_settings.get_values())
        window.update(window.annotator.merge_image_visualization())
        window.info_label.setText("Annotate all finished.")

    def on_cancel_annotation(self):
        """Discard the in-progress annotation and redraw."""
        window = self.parent()
        window.image_label.clear()
        window.annotator.clear_last_masks()
        window.update(window.annotator.merge_image_visualization())

    def on_save_annotation(self):
        """Save the current mask under the label selected in the picker.

        In polygon mode the drawn polygon is first rasterised to a mask using
        ``config.window_size``. Shows a warning and saves nothing when no label
        is selected.
        """
        window = self.parent()
        selected = self.label_picker.currentItem()
        if selected is None:
            # currentItem() is None when the list is empty or nothing is selected.
            QMessageBox.warning(self, "Error", "Please select a label before saving!")
            return

        if window.image_label.paint_type == PaintType.POLYGON:
            window.annotator.last_mask = window.image_label.polygon.to_mask(
                self.config.window_size[0], self.config.window_size[1])
        window.annotator.save_mask(label=selected.text())
        window.update(window.annotator.merge_image_visualization())
        window.image_label.clear()
|
|
|