Files
segment-anything-ui-gpu/segment_anything_ui/settings_layout.py
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

228 lines
10 KiB
Python

import json
import os
import pathlib
import random
import cv2
import numpy as np
from PySide6.QtWidgets import QPushButton, QWidget, QFileDialog, QVBoxLayout, QLineEdit, QLabel, QCheckBox, QMessageBox
from segment_anything_ui.annotator import MasksAnnotation
from segment_anything_ui.config import Config
from segment_anything_ui.utils.bounding_boxes import BoundingBox
class FilesHolder:
    """Cyclic cursor over a list of file paths used for next/previous navigation."""

    def __init__(self):
        self.files = []  # file paths in navigation order
        self.position = 0  # index into ``self.files`` of the current file

    def add_files(self, files):
        """Append *files* to the end of the navigation list."""
        self.files.extend(files)

    def get_next(self):
        """Advance to the next file, wrapping around to the first, and return it.

        Raises:
            IndexError: if no files have been added; ``position`` is left unchanged.
        """
        if not self.files:
            raise IndexError("FilesHolder is empty; add files before navigating")
        self.position = (self.position + 1) % len(self.files)
        return self.files[self.position]

    def get_previous(self):
        """Step back to the previous file, wrapping around to the last, and return it.

        Raises:
            IndexError: if no files have been added; ``position`` is left unchanged.
        """
        if not self.files:
            raise IndexError("FilesHolder is empty; add files before navigating")
        self.position = (self.position - 1) % len(self.files)
        return self.files[self.position]
class SettingsLayout(QWidget):
    """Side panel with file navigation, annotation save/load and display toggles.

    Annotations live next to each image as sidecar files sharing the image's stem:
    ``<stem>_mask.png`` (per-instance integer mask), ``<stem>_labels.json``
    (instance labels plus image tags) and ``<stem>_bounding_boxes.json``.
    """

    MASK_EXTENSION = "_mask.png"
    LABELS_EXTENSION = "_labels.json"
    BOUNDING_BOXES_EXTENSION = "_bounding_boxes.json"

    def __init__(self, parent, config: Config) -> None:
        super().__init__(parent)
        self.config = config
        self.actual_file: str = ""  # path of the image currently shown ("" = none yet)
        # (width, height) of the loaded source image; masks are resized back to it on save.
        self.actual_shape = self.config.window_size
        self.layout = QVBoxLayout(self)

        self.open_files = QPushButton("Open Files")
        self.open_files.clicked.connect(self.on_open_files)
        self.next_file = QPushButton(f"Next File [ {config.key_mapping.NEXT_FILE.name} ]")
        self.next_file.clicked.connect(self.on_next_file)
        self.next_file.setShortcut(config.key_mapping.NEXT_FILE.key)
        self.previous_file = QPushButton(f"Previous file [ {config.key_mapping.PREVIOUS_FILE.name} ]")
        self.previous_file.clicked.connect(self.on_previous_file)
        self.previous_file.setShortcut(config.key_mapping.PREVIOUS_FILE.key)
        self.save_mask = QPushButton(f"Save Mask [ {config.key_mapping.SAVE_MASK.name} ]")
        self.save_mask.clicked.connect(self.on_save_mask)
        self.save_mask.setShortcut(config.key_mapping.SAVE_MASK.key)
        self.save_bounding_boxes = QPushButton(f"Save Bounding Boxes [ {config.key_mapping.SAVE_BOUNDING_BOXES.name} ]")
        self.save_bounding_boxes.clicked.connect(self.on_save_bounding_boxes)
        self.save_bounding_boxes.setShortcut(config.key_mapping.SAVE_BOUNDING_BOXES.key)

        self.checkpoint_path_label = QLabel(self, text="Checkpoint Path")
        self.checkpoint_path = QLineEdit(self, text=self.parent().config.default_weights)
        # editingFinished already fires on Return/Enter; also connecting returnPressed
        # made a single Return press reload the model twice.
        self.checkpoint_path.editingFinished.connect(self.on_checkpoint_path_changed)

        self.precompute_button = QPushButton("Precompute all embeddings")
        self.precompute_button.clicked.connect(self.on_precompute)
        self.delete_existing_annotation = QPushButton("Delete existing annotation")
        self.delete_existing_annotation.clicked.connect(self.on_delete_existing_annotation)
        self.show_image = QPushButton("Show Image")
        self.show_image.clicked.connect(self.on_show_image)
        self.show_visualization = QPushButton("Show Visualization")
        self.show_visualization.clicked.connect(self.on_show_visualization)
        self.show_bounding_boxes = QCheckBox("Show Bounding Boxes")
        self.show_bounding_boxes.clicked.connect(self.on_show_bounding_boxes)
        self.show_text = QCheckBox("Show Text")
        self.show_text.clicked.connect(self.on_show_text)
        self.tag_text_field = QLineEdit(self)
        self.tag_text_field.setPlaceholderText("Comma separated image tags")

        # Widgets appear in the panel in this order.
        for widget in (
            self.open_files,
            self.next_file,
            self.previous_file,
            self.save_mask,
            self.save_bounding_boxes,
            self.delete_existing_annotation,
            self.show_text,
            self.show_bounding_boxes,
            self.tag_text_field,
            self.checkpoint_path_label,
            self.checkpoint_path,
            self.precompute_button,
            self.show_image,
            self.show_visualization,
        ):
            self.layout.addWidget(widget)

        self.files = FilesHolder()
        self.original_image = np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8)

    @staticmethod
    def _annotation_path(image_file: str, extension: str) -> str:
        """Return the sidecar path ``<dir>/<image stem><extension>`` for *image_file*.

        Used by every load/save/delete handler so they always agree on file names,
        including for paths whose directories or file names contain extra dots.
        """
        directory, filename = os.path.split(image_file)
        stem = os.path.splitext(filename)[0]
        return os.path.join(directory, stem + extension)

    def _refresh_visualization(self):
        """Redraw the parent's view with the annotator's merged visualization."""
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def on_delete_existing_annotation(self):
        """Delete the current image's saved mask and labels files, if they exist.

        The bounding-box sidecar file is not touched.
        """
        if not self.actual_file:
            # Without a loaded image the paths would resolve to "_mask.png" etc. in the CWD.
            return
        for extension in (self.MASK_EXTENSION, self.LABELS_EXTENSION):
            path = self._annotation_path(self.actual_file, extension)
            if os.path.exists(path):
                os.remove(path)

    def is_show_text(self):
        """Return whether the "Show Text" checkbox is checked."""
        return self.show_text.isChecked()

    def on_show_text(self):
        """Redraw the visualization after the "Show Text" toggle changes."""
        self._refresh_visualization()

    def on_next_file(self):
        """Load the next file in the list; no-op when no files are loaded."""
        if not self.files.files:
            return
        self._load_image(self.files.get_next())

    def on_previous_file(self):
        """Load the previous file in the list; no-op when no files are loaded."""
        if not self.files.files:
            return
        self._load_image(self.files.get_previous())

    @staticmethod
    def _to_display_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an image as returned by ``cv2.imread(..., IMREAD_UNCHANGED)`` to 8-bit RGB."""
        if image.dtype in (np.float32, np.float64, np.uint16):
            # Normalise before cvtColor: cvtColor rejects float64 input, and scaling by the
            # global maximum does not depend on channel order.
            max_value = np.amax(image)
            if max_value > 0:
                image = np.clip(image / max_value * 255, 0, 255).astype("uint8")
            else:
                # All-zero (or NaN-max) image: avoid dividing by zero.
                image = np.zeros(image.shape, dtype=np.uint8)
        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_image(self, file: str):
        """Show *file* in the viewer and restore any annotation saved next to it.

        Mask+labels files take precedence over a bounding-box file. The status text
        in the parent's info label reports which one (if any) was loaded.
        """
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread reports unreadable/unsupported files by returning None.
            self.parent().info_label.setText(f"Failed to read image: {file}")
            return
        mask = self._annotation_path(file, self.MASK_EXTENSION)
        labels = self._annotation_path(file, self.LABELS_EXTENSION)
        bounding_boxes = self._annotation_path(file, self.BOUNDING_BOXES_EXTENSION)

        self.actual_shape = image.shape[:2][::-1]
        self.actual_file = file
        image = self._to_display_rgb(image)
        window_size = self.parent().config.window_size
        image = cv2.resize(image, (int(window_size[0]), int(window_size[1])))

        self.parent().annotator.clear()
        self.parent().image_label.clear()
        self.original_image = image.copy()
        self.parent().set_image(image)
        # Reset tags up front so the previous image's tags never carry over;
        # _load_annotation refills the field when the labels file has tags.
        self.tag_text_field.setText("")

        if os.path.exists(mask) and os.path.exists(labels):
            if self._load_annotation(mask, labels):
                self.parent().info_label.setText(f"Loaded annotation from saved files! Image: {file}")
                self._refresh_visualization()
            else:
                self.parent().info_label.setText(f"Failed to read mask: {mask}")
        elif os.path.exists(bounding_boxes):
            self._load_bounding_boxes(bounding_boxes)
            self.parent().info_label.setText(f"Loaded bounding boxes from saved files! Image: {file}")
            self._refresh_visualization()
        else:
            self.parent().info_label.setText(f"No annotation found! Image: {file}")

    def _load_annotation(self, mask, labels) -> bool:
        """Rebuild per-instance masks and tags from a saved mask image and labels JSON.

        Each instance occupies the pixels of the mask equal to its integer id. The
        labels JSON maps those ids (as strings) to labels, either as a flat mapping
        or under an ``"instances"`` key next to an optional ``"tags"`` list.

        Returns:
            True on success, False if the mask image could not be read.
        """
        mask_image = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
        if mask_image is None:
            return False
        mask_image = cv2.resize(mask_image, (int(self.config.window_size[0]), int(self.config.window_size[1])),
                                interpolation=cv2.INTER_NEAREST)  # nearest keeps instance ids intact
        with open(labels, "r", encoding="utf-8") as fp:
            labels_data: dict = json.load(fp)

        instance_labels = labels_data["instances"] if "instances" in labels_data else labels_data
        if "tags" in labels_data:
            self.tag_text_field.setText(",".join(labels_data["tags"]))
        else:
            self.tag_text_field.setText("")

        masks = []
        new_labels = []
        for str_index, class_ in instance_labels.items():
            single_mask = np.zeros((mask_image.shape[0], mask_image.shape[1]), dtype=np.uint8)
            single_mask[mask_image == int(str_index)] = 255
            masks.append(single_mask)
            new_labels.append(class_)
        self.parent().annotator.masks = MasksAnnotation.from_masks(masks, new_labels)
        return True

    def _load_bounding_boxes(self, bounding_boxes):
        """Append the bounding boxes stored in the JSON file *bounding_boxes* to the annotator."""
        with open(bounding_boxes, "r", encoding="utf-8") as f:
            boxes: list[dict[str, float | str]] = json.load(f)
        for bounding_box in boxes:
            self.parent().annotator.bounding_boxes.append(BoundingBox(**bounding_box))

    def on_show_image(self):
        """Show the unannotated image, keeping the current annotations in memory."""
        self.parent().set_image(self.original_image, clear_annotations=False)

    def on_show_visualization(self):
        """Show the image with the annotation overlay."""
        self._refresh_visualization()

    def on_precompute(self):
        """Not implemented: the "Precompute all embeddings" button currently does nothing."""
        pass

    def on_save_mask(self):
        """Save the current masks (at source resolution) and labels/tags next to the image."""
        if not self.actual_file:
            self.parent().info_label.setText("No image loaded!")
            return
        # Drop blanks so an empty tag field is stored as [] instead of [""].
        tags = [tag.strip() for tag in self.tag_text_field.text().split(",")]
        tags = [tag for tag in tags if tag]
        mask_path = self._annotation_path(self.actual_file, self.MASK_EXTENSION)
        labels_path = self._annotation_path(self.actual_file, self.LABELS_EXTENSION)

        masks = self.parent().get_mask()
        labels = {"instances": self.parent().get_labels(), "tags": tags}
        # Nearest-neighbour keeps instance ids intact when scaling back to source size.
        masks = cv2.resize(masks, self.actual_shape, interpolation=cv2.INTER_NEAREST)
        if not cv2.imwrite(mask_path, masks):
            # cv2.imwrite reports failure via its return value, not an exception.
            self.parent().info_label.setText(f"Failed to save mask: {mask_path}")
            return
        with open(labels_path, "w", encoding="utf-8") as f:
            json.dump(labels, f, indent=4)

    def on_checkpoint_path_changed(self):
        """Re-initialise the SAM model after the checkpoint path is edited."""
        self.parent().sam = self.parent().init_sam()

    def on_open_files(self):
        """Ask for image files, append them in random order and show the next file."""
        files, _ = QFileDialog.getOpenFileNames(self, "Open Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)")
        if not files:
            # Dialog cancelled or nothing selected.
            return
        random.shuffle(files)
        self.files.add_files(files)
        self.on_next_file()

    def on_save_bounding_boxes(self):
        """Save the current bounding boxes as JSON next to the image."""
        if not self.actual_file:
            self.parent().info_label.setText("No image loaded!")
            return
        bounding_boxes_path = self._annotation_path(self.actual_file, self.BOUNDING_BOXES_EXTENSION)
        bounding_boxes_dict = [bounding_box.to_dict() for bounding_box in self.parent().get_bounding_boxes()]
        with open(bounding_boxes_path, "w", encoding="utf-8") as f:
            json.dump(bounding_boxes_dict, f, indent=4)

    def is_show_bounding_boxes(self):
        """Return whether the "Show Bounding Boxes" checkbox is checked."""
        return self.show_bounding_boxes.isChecked()

    def on_show_bounding_boxes(self):
        """Redraw the visualization after the "Show Bounding Boxes" toggle changes."""
        self._refresh_visualization()