Files
segment-anything-ui-gpu/segment_anything_ui/annotation_layout.py
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

219 lines
11 KiB
Python

import enum
import json
import os
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox
from segment_anything_ui.draw_label import PaintType
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation
class MergeState(enum.Enum):
    """Two-step state of the "Merge Masks" action.

    PICKING: no merge in progress; the next trigger starts one.
    MERGING: a first mask has been remembered; the next trigger merges.
    """

    PICKING = enum.auto()
    MERGING = enum.auto()
class AnnotationLayout(QWidget):
    """Vertical side panel holding the annotation controls.

    The panel contains the tool buttons, the label picker, the automatic
    mask generator settings form and the "remove hidden masks" controls.

    Every handler talks to the owning widget through ``self.parent()``,
    which must expose ``image_label``, ``info_label``, ``annotator`` and an
    ``update(image)`` method that accepts the merged visualization.
    """

    def __init__(self, parent, config) -> None:
        """Build the widgets, register shortcuts and wire the click handlers.

        Args:
            parent: Owning widget; handlers reach it through ``self.parent()``.
            config: Application configuration providing ``key_mapping``,
                ``label_file`` and ``window_size``.
        """
        super().__init__(parent)
        self.config = config
        self.zoom_flag = False
        self.merge_state = MergeState.PICKING
        # NOTE(review): this attribute shadows QWidget.layout(); the name is
        # kept because code outside this class may already read it.
        self.layout = QVBoxLayout(self)
        labels = self._load_labels(config)
        self.layout.setAlignment(Qt.AlignTop)
        self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]")
        self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]")
        self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]")
        self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]")
        self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]")
        self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]")
        self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]")
        self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]")
        self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]")
        self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]")
        self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]")
        self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]")
        self.label_picker = QListWidget()
        # The label file may decode to a mapping; list() yields its keys so the
        # widget always receives a plain list of label names.
        self.label_picker.addItems(list(labels))
        self.label_picker.setCurrentRow(0)
        self.move_current_mask_background = QPushButton("Move Current Mask to Front")
        self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
        self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than")
        self.remove_hidden_masks_line = QLineEdit("0.5")

        # NOTE(review): merge_masks and manual_polygon show a key in their
        # caption but get no shortcut here - confirm the key is handled
        # elsewhere before adding one.
        self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key)
        self.add_point.setShortcut(config.key_mapping.ADD_POINT.key)
        self.add_box.setShortcut(config.key_mapping.ADD_BOX.key)
        self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key)
        self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key)
        self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key)
        self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key)
        self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key)
        self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key)
        self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key)
        self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())

        # The order of this list is the top-to-bottom order in the panel.
        for w in [
            self.add_point,
            self.add_box,
            self.annotate_all,
            self.pick_mask,
            self.pick_bounding_box,
            self.merge_masks,
            self.move_current_mask_background,
            self.cancel_annotation,
            self.delete_mask,
            self.partial_annotation,
            self.save_annotation,
            self.manual_polygon,
            self.label_picker,
            self.zoom_rectangle,
            self.annotation_settings,
            self.remove_hidden_masks,
            self.remove_hidden_masks_label,
            self.remove_hidden_masks_line,
        ]:
            self.layout.addWidget(w)

        self.add_point.clicked.connect(self.on_add_point)
        self.add_box.clicked.connect(self.on_add_box)
        self.annotate_all.clicked.connect(self.on_annotate_all)
        self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
        self.save_annotation.clicked.connect(self.on_save_annotation)
        self.pick_mask.clicked.connect(self.on_pick_mask)
        self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
        self.manual_polygon.clicked.connect(self.on_manual_polygon)
        self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
        self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
        self.merge_masks.clicked.connect(self.on_merge_masks)
        self.partial_annotation.clicked.connect(self.on_partial_annotation)
        self.delete_mask.clicked.connect(self.on_delete_mask)
        self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)

    def on_delete_mask(self):
        """Delete the picked mask or bounding box, depending on the paint type.

        In mask-picker mode the current mask is popped and its bounding box
        removed; in box-picker mode the current box is removed and, when it
        maps to a mask, that mask is removed too. In any other mode a warning
        dialog is shown and nothing is deleted.
        """
        window = self.parent()
        annotator = window.annotator
        paint_type = window.image_label.paint_type
        if paint_type == PaintType.MASK_PICKER:
            window.info_label.setText("Deleting mask!")
            mask_uid = annotator.masks.pop(annotator.masks.mask_id)
            annotator.bounding_boxes.remove(mask_uid)
            annotator.masks.mask_id = -1
            annotator.last_mask = None
            window.update(annotator.merge_image_visualization())
        elif paint_type == PaintType.BOX_PICKER:
            window.info_label.setText("Deleting bounding box!")
            mask_uid = annotator.bounding_boxes.remove_by_id(
                annotator.bounding_boxes.bounding_box_id)
            if mask_uid is not None:
                annotator.masks.pop_by_uuid(mask_uid)
            annotator.bounding_boxes.bounding_box_id = -1
            annotator.last_mask = None
            annotator.masks.mask_id = -1
            window.update(annotator.merge_image_visualization())
        else:
            QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")

    def on_partial_annotation(self):
        """Ask the annotator to pick a partial mask and clear the image label."""
        window = self.parent()
        window.info_label.setText("Partial annotation!")
        window.annotator.pick_partial_mask()
        window.image_label.clear()

    @staticmethod
    def _load_labels(config):
        """Load the label names from ``config.label_file``.

        Also sets ``MasksAnnotation.DEFAULT_LABEL`` to the first label found,
        or to ``"default"`` when the file holds no labels.

        Args:
            config: Configuration object exposing ``label_file``.

        Returns:
            The decoded JSON content of the label file, or ``["default"]``
            when the file is missing or contains no labels.
        """
        if not os.path.exists(config.label_file):
            return ["default"]
        # JSON is UTF-8 by specification; without an explicit encoding the
        # platform default would mangle non-ASCII label names.
        with open(config.label_file, "r", encoding="utf-8") as f:
            labels = json.load(f)
        if not labels:
            # Keep the picker in step with the default label instead of
            # handing back an empty collection.
            MasksAnnotation.DEFAULT_LABEL = "default"
            return ["default"]
        MasksAnnotation.DEFAULT_LABEL = next(iter(labels))
        return labels

    def on_merge_masks(self):
        """Drive the two-step merge: remember a mask, then merge on the next call.

        The first call stores a copy of the annotator's last mask and moves
        to MERGING; the second call merges and returns to PICKING. If there
        is no last mask yet, a warning is shown and the state is unchanged.
        """
        window = self.parent()
        window.image_label.change_paint_type(PaintType.MASK_PICKER)
        if self.merge_state == MergeState.PICKING:
            last_mask = window.annotator.last_mask
            if last_mask is None:
                # last_mask is reset to None after deletions; copying it would
                # raise and leave the state machine stuck in MERGING.
                QMessageBox.warning(self, "Error", "Please pick a mask before merging!")
                return
            window.info_label.setText("Pick a mask to merge with!")
            window.annotator.merged_mask = last_mask.copy()
            self.merge_state = MergeState.MERGING
        elif self.merge_state == MergeState.MERGING:
            window.info_label.setText("Merging masks!")
            window.annotator.merge_masks()
            self.merge_state = MergeState.PICKING

    def on_move_current_mask_background_fn(self):
        """Move the current mask via the annotator and refresh the view."""
        # NOTE(review): the button caption says "to Front" while the status
        # text says "to background" - confirm which wording is intended.
        window = self.parent()
        window.info_label.setText("Moving current mask to background!")
        window.annotator.move_current_mask_to_background()
        window.update(window.annotator.merge_image_visualization())

    def on_remove_hidden_masks(self):
        """Drop masks whose visible fraction is not above the entered threshold.

        A mask's visible fraction is the number of pixels where the instance
        mask equals its 1-based index, divided by the mask's own positive
        pixel count. Masks with no positive pixels are dropped. A non-numeric
        threshold shows a warning and leaves the masks untouched.
        """
        window = self.parent()
        try:
            limit_ratio = float(self.remove_hidden_masks_line.text())
        except ValueError:
            QMessageBox.warning(self, "Error", "Hidden mask threshold must be a number!")
            return
        window.info_label.setText("Removing hidden masks!")
        annotations = window.annotator.masks
        argmax_mask = window.annotator.make_instance_mask()
        new_masks = []
        new_labels = []
        for idx, (mask, label) in enumerate(annotations):
            num_pixels = np.sum(mask > 0)
            if num_pixels == 0:
                # Avoid 0/0: an empty mask has nothing visible, so drop it.
                continue
            num_visible = np.sum(argmax_mask == (idx + 1))
            ratio = num_visible / num_pixels
            if ratio > limit_ratio:
                new_masks.append(mask)
                new_labels.append(label)
        print("Removed ", len(annotations) - len(new_masks), " masks.")
        window.annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
        window.update(window.annotator.merge_image_visualization())

    def on_pick_mask(self):
        """Switch the image label to mask-picker mode."""
        window = self.parent()
        window.info_label.setText("Pick a mask to do modifications!")
        window.image_label.change_paint_type(PaintType.MASK_PICKER)

    def on_pick_bounding_box(self):
        """Switch the image label to bounding-box-picker mode."""
        window = self.parent()
        window.info_label.setText("Pick a bounding box to do modifications!")
        window.image_label.change_paint_type(PaintType.BOX_PICKER)

    def on_manual_polygon(self):
        """Switch the image label to manual polygon drawing."""
        # Sets emphasis on the button
        self.manual_polygon.setProperty("active", True)
        self.parent().image_label.change_paint_type(PaintType.POLYGON)

    def on_add_point(self):
        """Switch the image label to point annotation mode."""
        window = self.parent()
        window.info_label.setText("Adding point annotation!")
        window.image_label.change_paint_type(PaintType.POINT)

    def on_add_box(self):
        """Switch the image label to box annotation mode."""
        window = self.parent()
        window.info_label.setText("Adding box annotation!")
        window.image_label.change_paint_type(PaintType.BOX)

    def on_zoom_rectangle(self):
        """Toggle zoom mode.

        Turning zoom off returns to point mode, clears the annotator's
        zoomed bounding box, calls ``make_embedding()`` and refreshes the
        view. Turning it on switches the image label to the zoom picker.
        """
        window = self.parent()
        if self.zoom_flag:
            window.info_label.setText("Zooming rectangle OFF!")
            window.image_label.change_paint_type(PaintType.POINT)
            window.annotator.zoomed_bounding_box = None
            window.annotator.make_embedding()
            window.update(window.annotator.merge_image_visualization())
            self.zoom_flag = False
        else:
            window.info_label.setText("Pick Mask to zoom!")
            self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
            window.image_label.change_paint_type(PaintType.ZOOM_PICKER)
            self.zoom_flag = True

    def on_annotate_all(self):
        """Run automatic annotation with the current settings form values."""
        window = self.parent()
        window.info_label.setText("Annotating all. This make take some time.")
        window.annotator.predict_all(self.annotation_settings.get_values())
        window.update(window.annotator.merge_image_visualization())
        window.info_label.setText("Annotate all finished.")

    def on_cancel_annotation(self):
        """Discard the in-progress annotation and refresh the view."""
        window = self.parent()
        window.image_label.clear()
        window.annotator.clear_last_masks()
        window.update(window.annotator.merge_image_visualization())

    def on_save_annotation(self):
        """Save the current mask under the label selected in the picker.

        In polygon mode the drawn polygon is first converted to a mask using
        ``config.window_size``. If no label is selected, a warning is shown
        and nothing is saved.
        """
        window = self.parent()
        current_item = self.label_picker.currentItem()
        if current_item is None:
            # currentItem() is None when the list has no current entry.
            QMessageBox.warning(self, "Error", "Please select a label before saving!")
            return
        if window.image_label.paint_type == PaintType.POLYGON:
            window.annotator.last_mask = window.image_label.polygon.to_mask(
                self.config.window_size[0], self.config.window_size[1])
        window.annotator.save_mask(label=current_item.text())
        window.update(window.annotator.merge_image_visualization())
        window.image_label.clear()