228 lines
10 KiB
Python
228 lines
10 KiB
Python
import json
|
|
import os
|
|
import pathlib
|
|
import random
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PySide6.QtWidgets import QPushButton, QWidget, QFileDialog, QVBoxLayout, QLineEdit, QLabel, QCheckBox, QMessageBox
|
|
|
|
from segment_anything_ui.annotator import MasksAnnotation
|
|
from segment_anything_ui.config import Config
|
|
from segment_anything_ui.utils.bounding_boxes import BoundingBox
|
|
|
|
|
|
class FilesHolder:
    """Cyclic cursor over a growing list of file paths.

    ``position`` starts at 0 and is moved *before* a file is returned, so the
    first ``get_next()`` call on a list with several entries returns the
    second entry; the first entry is reached once the cursor wraps around.
    """

    def __init__(self):
        self.files = []
        self.position = 0

    def add_files(self, files):
        """Append ``files`` to the end of the list; the cursor is not moved."""
        self.files.extend(files)

    def get_next(self):
        """Advance the cursor (wrapping to the start) and return that file.

        Raises:
            IndexError: if no files have been added yet.
        """
        self._ensure_not_empty()
        self.position += 1
        if self.position >= len(self.files):
            self.position = 0
        return self.files[self.position]

    def get_previous(self):
        """Step the cursor back (wrapping to the end) and return that file.

        Raises:
            IndexError: if no files have been added yet.
        """
        self._ensure_not_empty()
        self.position -= 1
        if self.position < 0:
            self.position = len(self.files) - 1
        return self.files[self.position]

    def _ensure_not_empty(self):
        """Fail early, before the cursor is touched, when there is nothing to return."""
        if not self.files:
            raise IndexError("FilesHolder is empty: add files before navigating")
|
|
|
|
|
|
class SettingsLayout(QWidget):
    """Side panel with file navigation, annotation save/load and display toggles.

    The widget talks to its Qt parent through ``self.parent()`` and expects it
    to provide ``annotator``, ``image_label``, ``info_label``, ``config`` and
    ``sam`` attributes as well as the ``set_image``, ``update``, ``get_mask``,
    ``get_labels``, ``get_bounding_boxes`` and ``init_sam`` methods.

    Annotation files live next to the image and share its name without the
    extension, followed by one of the ``*_EXTENSION`` suffixes below.
    """

    MASK_EXTENSION = "_mask.png"
    LABELS_EXTENSION = "_labels.json"
    BOUNDING_BOXES_EXTENSION = "_bounding_boxes.json"

    def __init__(self, parent, config: Config) -> None:
        """Build all buttons, check boxes and text fields and wire their signals.

        Args:
            parent: Qt parent widget; also the object all actions are forwarded to.
            config: application configuration (window size, key mapping, ...).
        """
        super().__init__(parent)
        self.config = config
        # Path of the image currently shown; empty until an image was loaded.
        self.actual_file: str = ""
        # (width, height) of the image on disk, used to save masks at full size.
        self.actual_shape = self.config.window_size
        # NOTE: this attribute shadows QWidget.layout(); the name is kept so
        # existing code that accesses ``.layout`` keeps working.
        self.layout = QVBoxLayout(self)

        keys = config.key_mapping

        self.open_files = QPushButton("Open Files")
        self.open_files.clicked.connect(self.on_open_files)

        self.next_file = QPushButton(f"Next File [ {keys.NEXT_FILE.name} ]")
        self.next_file.clicked.connect(self.on_next_file)
        self.next_file.setShortcut(keys.NEXT_FILE.key)

        self.previous_file = QPushButton(f"Previous file [ {keys.PREVIOUS_FILE.name} ]")
        self.previous_file.clicked.connect(self.on_previous_file)
        self.previous_file.setShortcut(keys.PREVIOUS_FILE.key)

        self.save_mask = QPushButton(f"Save Mask [ {keys.SAVE_MASK.name} ]")
        self.save_mask.clicked.connect(self.on_save_mask)
        self.save_mask.setShortcut(keys.SAVE_MASK.key)

        self.save_bounding_boxes = QPushButton(
            f"Save Bounding Boxes [ {keys.SAVE_BOUNDING_BOXES.name} ]"
        )
        self.save_bounding_boxes.clicked.connect(self.on_save_bounding_boxes)
        self.save_bounding_boxes.setShortcut(keys.SAVE_BOUNDING_BOXES.key)

        self.checkpoint_path_label = QLabel(self, text="Checkpoint Path")
        self.checkpoint_path = QLineEdit(self, text=self.parent().config.default_weights)
        # editingFinished is emitted on Return/Enter as well as on focus loss,
        # so connecting returnPressed too would reload the model twice per
        # Enter key press.
        self.checkpoint_path.editingFinished.connect(self.on_checkpoint_path_changed)

        self.precompute_button = QPushButton("Precompute all embeddings")
        self.precompute_button.clicked.connect(self.on_precompute)

        self.delete_existing_annotation = QPushButton("Delete existing annotation")
        self.delete_existing_annotation.clicked.connect(self.on_delete_existing_annotation)

        self.show_image = QPushButton("Show Image")
        self.show_image.clicked.connect(self.on_show_image)

        self.show_visualization = QPushButton("Show Visualization")
        self.show_visualization.clicked.connect(self.on_show_visualization)

        self.show_bounding_boxes = QCheckBox("Show Bounding Boxes")
        self.show_bounding_boxes.clicked.connect(self.on_show_bounding_boxes)

        self.show_text = QCheckBox("Show Text")
        self.show_text.clicked.connect(self.on_show_text)

        self.tag_text_field = QLineEdit(self)
        self.tag_text_field.setPlaceholderText("Comma separated image tags")

        # The tuple order is the top-to-bottom order of the panel.
        for widget in (
            self.open_files,
            self.next_file,
            self.previous_file,
            self.save_mask,
            self.save_bounding_boxes,
            self.delete_existing_annotation,
            self.show_text,
            self.show_bounding_boxes,
            self.tag_text_field,
            self.checkpoint_path_label,
            self.checkpoint_path,
            self.precompute_button,
            self.show_image,
            self.show_visualization,
        ):
            self.layout.addWidget(widget)

        self.files = FilesHolder()
        # Black placeholder (rows = window height, columns = window width)
        # until the first image is loaded.
        self.original_image = np.zeros(
            (self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8
        )

    @staticmethod
    def _annotation_path(file: str, extension: str) -> str:
        """Return ``file`` with its extension replaced by ``extension``.

        Only the final extension is stripped, so dots in directory names or
        inside the file name are preserved.
        """
        return os.path.splitext(file)[0] + extension

    def _refresh_visualization(self):
        """Ask the parent to redraw using the annotator's merged visualization."""
        window = self.parent()
        window.update(window.annotator.merge_image_visualization())

    def on_delete_existing_annotation(self):
        """Delete the mask and labels files of the current image, if present."""
        if not self.actual_file:
            # Without a loaded image the paths would resolve to bare
            # "_mask.png" / "_labels.json" in the working directory.
            return
        # NOTE(review): the bounding-box file is not removed here; confirm
        # whether that is intended.
        for extension in (self.MASK_EXTENSION, self.LABELS_EXTENSION):
            annotation_path = self._annotation_path(self.actual_file, extension)
            if os.path.exists(annotation_path):
                os.remove(annotation_path)

    def is_show_text(self):
        """Return whether the "Show Text" check box is ticked."""
        return self.show_text.isChecked()

    def on_show_text(self):
        """Redraw after the "Show Text" check box was toggled."""
        self._refresh_visualization()

    def on_next_file(self):
        """Load the next file of the list; does nothing while the list is empty."""
        if not self.files.files:
            return
        self._load_image(self.files.get_next())

    def on_previous_file(self):
        """Load the previous file of the list; does nothing while the list is empty."""
        if not self.files.files:
            return
        self._load_image(self.files.get_previous())

    @staticmethod
    def _to_uint8_rgb(image):
        """Convert an image read by ``cv2.imread`` to an 8-bit RGB array.

        float32, float64 and uint16 images are rescaled by their global
        maximum to the 0-255 range first; the scaling is element-wise, so
        doing it before the colour conversion gives the same pixels while
        also avoiding ``cv2.cvtColor`` on dtypes it does not accept.
        """
        if image.dtype in [np.float32, np.float64, np.uint16]:
            peak = np.amax(image)
            if peak > 0:
                image = (image / peak * 255).astype("uint8")
            else:
                # Nothing to scale; also avoids a division by zero.
                image = np.zeros(image.shape, dtype=np.uint8)
        if len(image.shape) == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_image(self, file: str):
        """Read ``file``, show it in the parent and restore saved annotations.

        A saved mask together with its labels file takes precedence over a
        saved bounding-box file. If the image cannot be read, the problem is
        reported through the parent's info label and the current image is
        left untouched.
        """
        window = self.parent()
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:
            window.info_label.setText(f"Could not read image: {file}")
            return

        mask = self._annotation_path(file, self.MASK_EXTENSION)
        labels = self._annotation_path(file, self.LABELS_EXTENSION)
        bounding_boxes = self._annotation_path(file, self.BOUNDING_BOXES_EXTENSION)

        # Original (width, height), needed to save the mask at full resolution.
        self.actual_shape = image.shape[:2][::-1]
        self.actual_file = file

        image = self._to_uint8_rgb(image)
        window_size = window.config.window_size
        image = cv2.resize(image, (int(window_size[0]), int(window_size[1])))

        window.annotator.clear()
        window.image_label.clear()
        self.original_image = image.copy()
        window.set_image(image)

        # Start from an empty tag field so tags of the previous image never
        # leak into this one; _load_annotation fills it in when tags exist.
        self.tag_text_field.setText("")
        if os.path.exists(mask) and os.path.exists(labels):
            self._load_annotation(mask, labels)
            window.info_label.setText(f"Loaded annotation from saved files! Image: {file}")
            self._refresh_visualization()
        elif os.path.exists(bounding_boxes):
            self._load_bounding_boxes(bounding_boxes)
            window.info_label.setText(f"Loaded bounding boxes from saved files! Image: {file}")
            self._refresh_visualization()
        else:
            window.info_label.setText(f"No annotation found! Image: {file}")

    def _load_annotation(self, mask, labels):
        """Restore instance masks, class labels and tags from saved files.

        Args:
            mask: path of the mask image; a pixel value equal to an instance
                index marks that instance.
            labels: path of the JSON file. Either a flat ``{index: class}``
                mapping or a dict with an ``"instances"`` mapping and an
                optional ``"tags"`` list.
        """
        index_image = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
        # Nearest-neighbour keeps the instance indices intact while scaling.
        index_image = cv2.resize(
            index_image,
            (self.config.window_size[0], self.config.window_size[1]),
            interpolation=cv2.INTER_NEAREST,
        )
        with open(labels, "r") as fp:
            content = json.load(fp)

        # Files without an "instances" key are treated as the flat mapping.
        instance_labels = content["instances"] if "instances" in content else content
        tags_text = ",".join(content["tags"]) if "tags" in content else ""
        self.tag_text_field.setText(tags_text)

        masks = []
        new_labels = []
        for str_index, class_ in instance_labels.items():
            single_mask = np.zeros((index_image.shape[0], index_image.shape[1]), dtype=np.uint8)
            single_mask[index_image == int(str_index)] = 255
            masks.append(single_mask)
            new_labels.append(class_)
        self.parent().annotator.masks = MasksAnnotation.from_masks(masks, new_labels)

    def _load_bounding_boxes(self, bounding_boxes):
        """Append every box stored in the JSON file to the annotator.

        Each list entry is passed to ``BoundingBox`` as keyword arguments.
        """
        with open(bounding_boxes, "r") as f:
            stored_boxes: list[dict[str, float | str]] = json.load(f)
        annotator = self.parent().annotator
        for stored_box in stored_boxes:
            annotator.bounding_boxes.append(BoundingBox(**stored_box))

    def on_show_image(self):
        """Show the plain image again without dropping the annotations."""
        self.parent().set_image(self.original_image, clear_annotations=False)

    def on_show_visualization(self):
        """Show the image merged with the annotation visualization."""
        self._refresh_visualization()

    def on_precompute(self):
        """Placeholder for the "Precompute all embeddings" button; does nothing."""
        pass

    def on_save_mask(self):
        """Write the labels JSON and the mask image next to the current image.

        The mask is resized back to the size of the image on disk. Tags are
        read from the tag field, stripped, and empty entries are dropped.
        """
        if not self.actual_file:
            self.parent().info_label.setText("No image loaded, nothing to save!")
            return
        stripped_tags = (tag.strip() for tag in self.tag_text_field.text().split(","))
        tags = [tag for tag in stripped_tags if tag]
        mask_path = self._annotation_path(self.actual_file, self.MASK_EXTENSION)
        labels_path = self._annotation_path(self.actual_file, self.LABELS_EXTENSION)

        masks = self.parent().get_mask()
        labels = {"instances": self.parent().get_labels(), "tags": tags}
        with open(labels_path, "w") as f:
            json.dump(labels, f, indent=4)
        # Nearest-neighbour so no new pixel values are invented while scaling.
        masks = cv2.resize(masks, self.actual_shape, interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(mask_path, masks)

    def on_checkpoint_path_changed(self):
        """Re-create the parent's ``sam`` object after the checkpoint path was edited."""
        self.parent().sam = self.parent().init_sam()

    def on_open_files(self):
        """Let the user pick images, add them in random order and load one.

        Cancelling the dialog leaves the file list and the shown image untouched.
        """
        files, _ = QFileDialog.getOpenFileNames(
            self, "Open Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)"
        )
        if not files:
            return
        random.shuffle(files)
        self.files.add_files(files)
        self.on_next_file()

    def on_save_bounding_boxes(self):
        """Write the current bounding boxes as JSON next to the current image."""
        if not self.actual_file:
            self.parent().info_label.setText("No image loaded, nothing to save!")
            return
        bounding_boxes_path = self._annotation_path(
            self.actual_file, self.BOUNDING_BOXES_EXTENSION
        )
        bounding_boxes = self.parent().get_bounding_boxes()
        bounding_boxes_dict = [bounding_box.to_dict() for bounding_box in bounding_boxes]
        with open(bounding_boxes_path, "w") as f:
            json.dump(bounding_boxes_dict, f, indent=4)

    def is_show_bounding_boxes(self):
        """Return whether the "Show Bounding Boxes" check box is ticked."""
        return self.show_bounding_boxes.isChecked()

    def on_show_bounding_boxes(self):
        """Redraw after the "Show Bounding Boxes" check box was toggled."""
        self._refresh_visualization()