initial_tune
This commit is contained in:
218
segment_anything_ui/annotation_layout.py
Normal file
218
segment_anything_ui/annotation_layout.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import enum
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
from PySide6.QtCore import Qt
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox
|
||||
|
||||
from segment_anything_ui.draw_label import PaintType
|
||||
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation
|
||||
|
||||
|
||||
class MergeState(enum.Enum):
|
||||
PICKING = enum.auto()
|
||||
MERGING = enum.auto()
|
||||
|
||||
|
||||
class AnnotationLayout(QWidget):
|
||||
|
||||
def __init__(self, parent, config) -> None:
|
||||
super().__init__(parent)
|
||||
self.config = config
|
||||
self.zoom_flag = False
|
||||
self.merge_state = MergeState.PICKING
|
||||
self.layout = QVBoxLayout(self)
|
||||
labels = self._load_labels(config)
|
||||
self.layout.setAlignment(Qt.AlignTop)
|
||||
self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]")
|
||||
self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]")
|
||||
self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]")
|
||||
self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]")
|
||||
self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]")
|
||||
self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]")
|
||||
self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]")
|
||||
self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]")
|
||||
self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]")
|
||||
self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]")
|
||||
self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]")
|
||||
self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]")
|
||||
self.label_picker = QListWidget()
|
||||
self.label_picker.addItems(labels)
|
||||
self.label_picker.setCurrentRow(0)
|
||||
self.move_current_mask_background = QPushButton("Move Current Mask to Front")
|
||||
self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
|
||||
self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than")
|
||||
self.remove_hidden_masks_line = QLineEdit("0.5")
|
||||
self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key)
|
||||
self.add_point.setShortcut(config.key_mapping.ADD_POINT.key)
|
||||
self.add_box.setShortcut(config.key_mapping.ADD_BOX.key)
|
||||
self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key)
|
||||
self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key)
|
||||
self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key)
|
||||
self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key)
|
||||
self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key)
|
||||
self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key)
|
||||
self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key)
|
||||
|
||||
self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())
|
||||
for w in [
|
||||
self.add_point,
|
||||
self.add_box,
|
||||
self.annotate_all,
|
||||
self.pick_mask,
|
||||
self.pick_bounding_box,
|
||||
self.merge_masks,
|
||||
self.move_current_mask_background,
|
||||
self.cancel_annotation,
|
||||
self.delete_mask,
|
||||
self.partial_annotation,
|
||||
self.save_annotation,
|
||||
self.manual_polygon,
|
||||
self.label_picker,
|
||||
self.zoom_rectangle,
|
||||
self.annotation_settings,
|
||||
self.remove_hidden_masks,
|
||||
self.remove_hidden_masks_label,
|
||||
self.remove_hidden_masks_line
|
||||
]:
|
||||
self.layout.addWidget(w)
|
||||
self.add_point.clicked.connect(self.on_add_point)
|
||||
self.add_box.clicked.connect(self.on_add_box)
|
||||
self.annotate_all.clicked.connect(self.on_annotate_all)
|
||||
self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
|
||||
self.save_annotation.clicked.connect(self.on_save_annotation)
|
||||
self.pick_mask.clicked.connect(self.on_pick_mask)
|
||||
self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
|
||||
self.manual_polygon.clicked.connect(self.on_manual_polygon)
|
||||
self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
|
||||
self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
|
||||
self.merge_masks.clicked.connect(self.on_merge_masks)
|
||||
self.partial_annotation.clicked.connect(self.on_partial_annotation)
|
||||
self.delete_mask.clicked.connect(self.on_delete_mask)
|
||||
self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)
|
||||
|
||||
def on_delete_mask(self):
|
||||
if self.parent().image_label.paint_type == PaintType.MASK_PICKER:
|
||||
self.parent().info_label.setText("Deleting mask!")
|
||||
mask_uid = self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
|
||||
self.parent().annotator.bounding_boxes.remove(mask_uid)
|
||||
self.parent().annotator.masks.mask_id = -1
|
||||
self.parent().annotator.last_mask = None
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
elif self.parent().image_label.paint_type == PaintType.BOX_PICKER:
|
||||
self.parent().info_label.setText("Deleting bounding box!")
|
||||
mask_uid = self.parent().annotator.bounding_boxes.remove_by_id(
|
||||
self.parent().annotator.bounding_boxes.bounding_box_id)
|
||||
if mask_uid is not None:
|
||||
self.parent().annotator.masks.pop_by_uuid(mask_uid)
|
||||
self.parent().annotator.bounding_boxes.bounding_box_id = -1
|
||||
self.parent().annotator.last_mask = None
|
||||
self.parent().annotator.masks.mask_id = -1
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
else:
|
||||
QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")
|
||||
|
||||
def on_partial_annotation(self):
|
||||
self.parent().info_label.setText("Partial annotation!")
|
||||
self.parent().annotator.pick_partial_mask()
|
||||
self.parent().image_label.clear()
|
||||
|
||||
@staticmethod
|
||||
def _load_labels(config):
|
||||
if not os.path.exists(config.label_file):
|
||||
return ["default"]
|
||||
with open(config.label_file, "r") as f:
|
||||
labels = json.load(f)
|
||||
MasksAnnotation.DEFAULT_LABEL = list(labels.keys())[0] if len(labels) > 0 else "default"
|
||||
return labels
|
||||
|
||||
def on_merge_masks(self):
|
||||
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
|
||||
if self.merge_state == MergeState.PICKING:
|
||||
self.parent().info_label.setText("Pick a mask to merge with!")
|
||||
self.merge_state = MergeState.MERGING
|
||||
self.parent().annotator.merged_mask = self.parent().annotator.last_mask.copy()
|
||||
elif self.merge_state == MergeState.MERGING:
|
||||
self.parent().info_label.setText("Merging masks!")
|
||||
self.parent().annotator.merge_masks()
|
||||
self.merge_state = MergeState.PICKING
|
||||
|
||||
def on_move_current_mask_background_fn(self):
|
||||
self.parent().info_label.setText("Moving current mask to background!")
|
||||
self.parent().annotator.move_current_mask_to_background()
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
|
||||
def on_remove_hidden_masks(self):
|
||||
self.parent().info_label.setText("Removing hidden masks!")
|
||||
annotations = self.parent().annotator.masks
|
||||
argmax_mask = self.parent().annotator.make_instance_mask()
|
||||
limit_ratio = float(self.remove_hidden_masks_line.text())
|
||||
new_masks = []
|
||||
new_labels = []
|
||||
for idx, (mask, label) in enumerate(annotations):
|
||||
num_pixels = np.sum(mask > 0)
|
||||
num_visible = np.sum(argmax_mask == (idx + 1))
|
||||
ratio = num_visible / num_pixels
|
||||
|
||||
if ratio > limit_ratio:
|
||||
new_masks.append(mask)
|
||||
new_labels.append(label)
|
||||
|
||||
print("Removed ", len(annotations) - len(new_masks), " masks.")
|
||||
self.parent().annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
|
||||
def on_pick_mask(self):
|
||||
self.parent().info_label.setText("Pick a mask to do modifications!")
|
||||
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
|
||||
|
||||
def on_pick_bounding_box(self):
|
||||
self.parent().info_label.setText("Pick a bounding box to do modifications!")
|
||||
self.parent().image_label.change_paint_type(PaintType.BOX_PICKER)
|
||||
|
||||
def on_manual_polygon(self):
|
||||
# Sets emphasis on the button
|
||||
self.manual_polygon.setProperty("active", True)
|
||||
self.parent().image_label.change_paint_type(PaintType.POLYGON)
|
||||
|
||||
def on_add_point(self):
|
||||
self.parent().info_label.setText("Adding point annotation!")
|
||||
self.parent().image_label.change_paint_type(PaintType.POINT)
|
||||
|
||||
def on_add_box(self):
|
||||
self.parent().info_label.setText("Adding box annotation!")
|
||||
self.parent().image_label.change_paint_type(PaintType.BOX)
|
||||
|
||||
def on_zoom_rectangle(self):
|
||||
if self.zoom_flag:
|
||||
self.parent().info_label.setText("Zooming rectangle OFF!")
|
||||
self.parent().image_label.change_paint_type(PaintType.POINT)
|
||||
self.parent().annotator.zoomed_bounding_box = None
|
||||
self.parent().annotator.make_embedding()
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
self.zoom_flag = False
|
||||
else:
|
||||
self.parent().info_label.setText("Pick Mask to zoom!")
|
||||
self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
|
||||
self.parent().image_label.change_paint_type(PaintType.ZOOM_PICKER)
|
||||
self.zoom_flag = True
|
||||
|
||||
def on_annotate_all(self):
|
||||
self.parent().info_label.setText("Annotating all. This make take some time.")
|
||||
self.parent().annotator.predict_all(self.annotation_settings.get_values())
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
self.parent().info_label.setText("Annotate all finished.")
|
||||
|
||||
def on_cancel_annotation(self):
|
||||
self.parent().image_label.clear()
|
||||
self.parent().annotator.clear_last_masks()
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
|
||||
def on_save_annotation(self):
|
||||
if self.parent().image_label.paint_type == PaintType.POLYGON:
|
||||
self.parent().annotator.last_mask = self.parent().image_label.polygon.to_mask(
|
||||
self.config.window_size[0], self.config.window_size[1])
|
||||
self.parent().annotator.save_mask(label=self.label_picker.currentItem().text())
|
||||
self.parent().update(self.parent().annotator.merge_image_visualization())
|
||||
self.parent().image_label.clear()
|
||||
|
||||
Reference in New Issue
Block a user