initial_tune

This commit is contained in:
AI-team\cyhan
2025-05-12 11:23:49 +09:00
commit b436a81091
33 changed files with 2398 additions and 0 deletions

View File

@@ -0,0 +1,218 @@
import enum
import json
import os
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox
from segment_anything_ui.draw_label import PaintType
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation
class MergeState(enum.Enum):
PICKING = enum.auto()
MERGING = enum.auto()
class AnnotationLayout(QWidget):
def __init__(self, parent, config) -> None:
super().__init__(parent)
self.config = config
self.zoom_flag = False
self.merge_state = MergeState.PICKING
self.layout = QVBoxLayout(self)
labels = self._load_labels(config)
self.layout.setAlignment(Qt.AlignTop)
self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]")
self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]")
self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]")
self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]")
self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]")
self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]")
self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]")
self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]")
self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]")
self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]")
self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]")
self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]")
self.label_picker = QListWidget()
self.label_picker.addItems(labels)
self.label_picker.setCurrentRow(0)
self.move_current_mask_background = QPushButton("Move Current Mask to Front")
self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than")
self.remove_hidden_masks_line = QLineEdit("0.5")
self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key)
self.add_point.setShortcut(config.key_mapping.ADD_POINT.key)
self.add_box.setShortcut(config.key_mapping.ADD_BOX.key)
self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key)
self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key)
self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key)
self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key)
self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key)
self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key)
self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key)
self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())
for w in [
self.add_point,
self.add_box,
self.annotate_all,
self.pick_mask,
self.pick_bounding_box,
self.merge_masks,
self.move_current_mask_background,
self.cancel_annotation,
self.delete_mask,
self.partial_annotation,
self.save_annotation,
self.manual_polygon,
self.label_picker,
self.zoom_rectangle,
self.annotation_settings,
self.remove_hidden_masks,
self.remove_hidden_masks_label,
self.remove_hidden_masks_line
]:
self.layout.addWidget(w)
self.add_point.clicked.connect(self.on_add_point)
self.add_box.clicked.connect(self.on_add_box)
self.annotate_all.clicked.connect(self.on_annotate_all)
self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
self.save_annotation.clicked.connect(self.on_save_annotation)
self.pick_mask.clicked.connect(self.on_pick_mask)
self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
self.manual_polygon.clicked.connect(self.on_manual_polygon)
self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
self.merge_masks.clicked.connect(self.on_merge_masks)
self.partial_annotation.clicked.connect(self.on_partial_annotation)
self.delete_mask.clicked.connect(self.on_delete_mask)
self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)
def on_delete_mask(self):
if self.parent().image_label.paint_type == PaintType.MASK_PICKER:
self.parent().info_label.setText("Deleting mask!")
mask_uid = self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
self.parent().annotator.bounding_boxes.remove(mask_uid)
self.parent().annotator.masks.mask_id = -1
self.parent().annotator.last_mask = None
self.parent().update(self.parent().annotator.merge_image_visualization())
elif self.parent().image_label.paint_type == PaintType.BOX_PICKER:
self.parent().info_label.setText("Deleting bounding box!")
mask_uid = self.parent().annotator.bounding_boxes.remove_by_id(
self.parent().annotator.bounding_boxes.bounding_box_id)
if mask_uid is not None:
self.parent().annotator.masks.pop_by_uuid(mask_uid)
self.parent().annotator.bounding_boxes.bounding_box_id = -1
self.parent().annotator.last_mask = None
self.parent().annotator.masks.mask_id = -1
self.parent().update(self.parent().annotator.merge_image_visualization())
else:
QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")
def on_partial_annotation(self):
self.parent().info_label.setText("Partial annotation!")
self.parent().annotator.pick_partial_mask()
self.parent().image_label.clear()
@staticmethod
def _load_labels(config):
if not os.path.exists(config.label_file):
return ["default"]
with open(config.label_file, "r") as f:
labels = json.load(f)
MasksAnnotation.DEFAULT_LABEL = list(labels.keys())[0] if len(labels) > 0 else "default"
return labels
def on_merge_masks(self):
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
if self.merge_state == MergeState.PICKING:
self.parent().info_label.setText("Pick a mask to merge with!")
self.merge_state = MergeState.MERGING
self.parent().annotator.merged_mask = self.parent().annotator.last_mask.copy()
elif self.merge_state == MergeState.MERGING:
self.parent().info_label.setText("Merging masks!")
self.parent().annotator.merge_masks()
self.merge_state = MergeState.PICKING
def on_move_current_mask_background_fn(self):
self.parent().info_label.setText("Moving current mask to background!")
self.parent().annotator.move_current_mask_to_background()
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_remove_hidden_masks(self):
self.parent().info_label.setText("Removing hidden masks!")
annotations = self.parent().annotator.masks
argmax_mask = self.parent().annotator.make_instance_mask()
limit_ratio = float(self.remove_hidden_masks_line.text())
new_masks = []
new_labels = []
for idx, (mask, label) in enumerate(annotations):
num_pixels = np.sum(mask > 0)
num_visible = np.sum(argmax_mask == (idx + 1))
ratio = num_visible / num_pixels
if ratio > limit_ratio:
new_masks.append(mask)
new_labels.append(label)
print("Removed ", len(annotations) - len(new_masks), " masks.")
self.parent().annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_pick_mask(self):
self.parent().info_label.setText("Pick a mask to do modifications!")
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
def on_pick_bounding_box(self):
self.parent().info_label.setText("Pick a bounding box to do modifications!")
self.parent().image_label.change_paint_type(PaintType.BOX_PICKER)
def on_manual_polygon(self):
# Sets emphasis on the button
self.manual_polygon.setProperty("active", True)
self.parent().image_label.change_paint_type(PaintType.POLYGON)
def on_add_point(self):
self.parent().info_label.setText("Adding point annotation!")
self.parent().image_label.change_paint_type(PaintType.POINT)
def on_add_box(self):
self.parent().info_label.setText("Adding box annotation!")
self.parent().image_label.change_paint_type(PaintType.BOX)
def on_zoom_rectangle(self):
if self.zoom_flag:
self.parent().info_label.setText("Zooming rectangle OFF!")
self.parent().image_label.change_paint_type(PaintType.POINT)
self.parent().annotator.zoomed_bounding_box = None
self.parent().annotator.make_embedding()
self.parent().update(self.parent().annotator.merge_image_visualization())
self.zoom_flag = False
else:
self.parent().info_label.setText("Pick Mask to zoom!")
self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
self.parent().image_label.change_paint_type(PaintType.ZOOM_PICKER)
self.zoom_flag = True
def on_annotate_all(self):
self.parent().info_label.setText("Annotating all. This make take some time.")
self.parent().annotator.predict_all(self.annotation_settings.get_values())
self.parent().update(self.parent().annotator.merge_image_visualization())
self.parent().info_label.setText("Annotate all finished.")
def on_cancel_annotation(self):
self.parent().image_label.clear()
self.parent().annotator.clear_last_masks()
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_save_annotation(self):
if self.parent().image_label.paint_type == PaintType.POLYGON:
self.parent().annotator.last_mask = self.parent().image_label.polygon.to_mask(
self.config.window_size[0], self.config.window_size[1])
self.parent().annotator.save_mask(label=self.label_picker.currentItem().text())
self.parent().update(self.parent().annotator.merge_image_visualization())
self.parent().image_label.clear()