Files
segment-anything-ui-gpu/segment_anything_ui/draw_label.py
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

269 lines
11 KiB
Python

import copy
from enum import Enum
import cv2
import numpy as np
import PySide6
from PySide6 import QtCore, QtWidgets
from PySide6.QtGui import QPainter, QPen
from segment_anything_ui.config import Config
from segment_anything_ui.utils.shapes import BoundingBox, Polygon
class PaintType(Enum):
    """Interaction mode of ``DrawLabel``; selects how mouse input is interpreted."""

    POINT = 0        # left / right clicks add positive / negative prompt points
    BOX = 1          # left-drag draws a bounding-box prompt
    MASK = 2         # no dedicated mouse handling here; releases still trigger a prediction
    POLYGON = 3      # left clicks append polygon vertices
    MASK_PICKER = 4  # left click selects a stored mask under the cursor
    ZOOM_PICKER = 5  # left-drag selects a rectangle to zoom into
    BOX_PICKER = 6   # left click selects the closest stored bounding box
class MaskIdPicker:
    """Cycle through candidate mask ids on repeated picks.

    Each call to :meth:`pick` returns the next id from the supplied sequence, so
    repeatedly clicking a spot covered by several masks selects each of them in
    turn.
    """

    def __init__(self, length) -> None:
        # Index (into the ids passed to ``pick``) of the next id to return.
        self.counter = 0
        # Minimum cycle length; ``increment`` wraps at exactly this value.
        self.length = length

    def increment(self):
        """Advance the counter by one, wrapping around at ``self.length``."""
        self.counter = (self.counter + 1) % self.length

    def pick(self, ids):
        """Return the next id from *ids* and advance the internal counter.

        The counter restarts at 0 whenever it points past the end of *ids*.

        Args:
            ids: Non-empty indexable sequence of candidate ids.

        Returns:
            The element of *ids* at the current counter position.

        Raises:
            IndexError: If *ids* is empty.
        """
        print("Length of ids: ", len(ids), " counter: ", self.counter, " ids: ", ids)
        if len(ids) <= self.counter:
            self.counter = 0
        return_id = ids[self.counter]
        # Wrap at the larger of the configured length and the number of candidates.
        # Wrapping only at ``self.length`` (as ``increment`` does) would make ids at
        # positions >= length unreachable when more masks than that overlap.
        # For len(ids) <= length this is the same sequence as before.
        self.counter = (self.counter + 1) % max(self.length, len(ids))
        return return_id
class DrawLabel(QtWidgets.QLabel):
    """Image label that collects user prompts and forwards them to the annotator.

    Depending on the current :class:`PaintType`, mouse input is interpreted as
    positive/negative prompt points, a bounding box, polygon vertices, a zoom
    rectangle, or a click that selects an existing mask or bounding box.

    The parent widget must expose an ``annotator`` attribute. It must also have
    an ``update`` method that accepts the result of
    ``annotator.merge_image_visualization()``, because both are called from here.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # Prompt points in widget coordinates (QPoint instances from mouse events).
        self.positive_points = []
        self.negative_points = []
        # Box prompt; None until a left-button press in BOX / ZOOM_PICKER mode.
        self.bounding_box = None
        # Rubber-band box shown while the user is still dragging.
        self.partial_box = BoundingBox(0, 0)
        self._paint_type = PaintType.POINT
        self.polygon = Polygon()
        # Cycles through overlapping masks on repeated MASK_PICKER clicks.
        self.mask_enum: MaskIdPicker = MaskIdPicker(3)
        self.config = Config()
        # StrongFocus lets the label receive key presses (e.g. "D" to delete a mask).
        self.setFocusPolicy(QtCore.Qt.StrongFocus)
        self._zoom_center = (0, 0)
        self._zoom_factor = 1.0
        self._zoom_bounding_box: BoundingBox | None = None

    def paintEvent(self, paint_event):
        """Draw the pixmap, then overlay boxes, prompt points and the polygon."""
        painter = QPainter(self)
        painter.drawPixmap(self.rect(), self.pixmap())
        pen_positive = self._get_pen(QtCore.Qt.green, 5)
        pen_negative = self._get_pen(QtCore.Qt.red, 5)
        pen_partial = self._get_pen(QtCore.Qt.yellow, 1)
        pen_box = self._get_pen(QtCore.Qt.green, 1)
        painter.setRenderHint(QPainter.Antialiasing, False)

        painter.setPen(pen_box)
        # An xend/yend of -1 marks a box whose second corner has not been set yet.
        if self.bounding_box is not None and self.bounding_box.xend != -1 and self.bounding_box.yend != -1:
            painter.drawRect(
                self.bounding_box.xstart,
                self.bounding_box.ystart,
                self.bounding_box.xend - self.bounding_box.xstart,
                self.bounding_box.yend - self.bounding_box.ystart
            )

        painter.setPen(pen_partial)
        painter.drawRect(self.partial_box.xstart, self.partial_box.ystart,
                         self.partial_box.xend - self.partial_box.xstart,
                         self.partial_box.yend - self.partial_box.ystart)

        painter.setPen(pen_positive)
        for pos in self.positive_points:
            painter.drawPoint(pos)

        painter.setPen(pen_negative)
        for pos in self.negative_points:
            painter.drawPoint(pos)

        if self.polygon.is_plotable():
            painter.setPen(pen_box)
            painter.setRenderHint(QPainter.Antialiasing, True)
            painter.drawPolygon(self.polygon.to_qpolygon())
        painter.end()

    def _get_pen(self, color=QtCore.Qt.red, width=5):
        """Return a QPen with the given color and width."""
        pen = QPen()
        pen.setWidth(width)
        pen.setColor(color)
        return pen

    @property
    def paint_type(self):
        """Current interaction mode (read-only; use ``change_paint_type``)."""
        return self._paint_type

    def change_paint_type(self, paint_type: PaintType):
        """Switch the interaction mode."""
        print(f"Changing paint type to {paint_type}")
        self._paint_type = paint_type

    def mouseMoveEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
        """Update the rubber-band box, or run a live point prediction while dragging."""
        if self._paint_type in (PaintType.BOX, PaintType.ZOOM_PICKER):
            # A box only exists after a left-button press (see mousePressEvent).
            # Without one there is nothing to preview, and the original code
            # crashed setting attributes on a deep copy of None.
            if self.bounding_box is not None:
                self.partial_box = copy.deepcopy(self.bounding_box)
                self.partial_box.xend = ev.pos().x()
                self.partial_box.yend = ev.pos().y()
                self.update()
        if self._paint_type == PaintType.POINT:
            point = ev.pos()
            if ev.buttons() == QtCore.Qt.LeftButton:
                self._move_update(None, point)
            elif ev.buttons() == QtCore.Qt.RightButton:
                self._move_update(point, None)
        self.update()

    def _move_update(self, temporary_point_negative, temporary_point_positive):
        """Predict and show a mask with one extra, not-yet-committed prompt point.

        Note the argument order: the negative point comes first.
        """
        annotations = self.get_annotations(temporary_point_positive, temporary_point_negative)
        self.parent().annotator.make_prediction(annotations)
        self.parent().annotator.visualize_last_mask()

    def mouseReleaseEvent(self, cursor_event):
        """Commit the prompt for the current mode and refresh the prediction."""
        if self._paint_type == PaintType.POINT:
            if cursor_event.button() == QtCore.Qt.LeftButton:
                self.positive_points.append(cursor_event.pos())
                print(self.size())
            elif cursor_event.button() == QtCore.Qt.RightButton:
                self.negative_points.append(cursor_event.pos())
        elif self._paint_type in (PaintType.BOX, PaintType.ZOOM_PICKER):
            # Guard against a release without a preceding left press in this mode.
            if cursor_event.button() == QtCore.Qt.LeftButton and self.bounding_box is not None:
                self.bounding_box.xend = cursor_event.pos().x()
                self.bounding_box.yend = cursor_event.pos().y()
                self.partial_box = BoundingBox(-1, -1, -1, -1)

        if self._paint_type not in (PaintType.MASK_PICKER, PaintType.ZOOM_PICKER,
                                    PaintType.POLYGON, PaintType.BOX_PICKER):
            self.parent().annotator.make_prediction(self.get_annotations())
            self.parent().annotator.visualize_last_mask()

        # Zoom only on a left release that completed a box. Any other release
        # previously dereferenced a None box, or zoomed to a half-built one.
        if (self._paint_type == PaintType.ZOOM_PICKER
                and cursor_event.button() == QtCore.Qt.LeftButton
                and self.bounding_box is not None):
            # Store the zoom rectangle in config.window_size coordinates.
            self.parent().annotator.zoomed_bounding_box = self.bounding_box.scale(*self._get_scale()).to_int()
            self.bounding_box = None
            self.parent().annotator.make_embedding()
            self.parent().update(self.parent().annotator.merge_image_visualization())
            self._paint_type = PaintType.POINT
        self.update()

    def mousePressEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
        """Start a box, add a polygon vertex, pick a mask/box, or preview a point."""
        left_pressed = ev.button() == QtCore.Qt.LeftButton
        if self._paint_type in (PaintType.BOX, PaintType.ZOOM_PICKER) and left_pressed:
            self.bounding_box = BoundingBox(xstart=ev.pos().x(), ystart=ev.pos().y())
        if self._paint_type == PaintType.POLYGON and left_pressed:
            self.polygon.points.append([ev.pos().x(), ev.pos().y()])
        if self._paint_type == PaintType.MASK_PICKER and left_pressed:
            self._pick_mask(ev)
        if self._paint_type == PaintType.BOX_PICKER and left_pressed:
            self._pick_bounding_box(ev)
        if self._paint_type == PaintType.POINT:
            point = ev.pos()
            if ev.button() == QtCore.Qt.LeftButton:
                self._move_update(None, point)
            if ev.button() == QtCore.Qt.RightButton:
                self._move_update(point, None)
        self.update()

    def _pick_mask(self, ev):
        """Select a stored mask covering the clicked pixel, cycling on repeated clicks."""
        size = self.size()
        # Map the click from widget pixels to the config.window_size grid;
        # the masks are indexed below as [mask, y, x] on that grid.
        point = [
            int(ev.pos().x() / size.width() * self.config.window_size[0]),
            int(ev.pos().y() / size.height() * self.config.window_size[1])]
        annotator = self.parent().annotator
        masks = np.array(annotator.masks.masks)
        if len(masks) == 0:
            # Nothing stored yet; indexing masks[:, y, x] on an empty array would raise.
            print("Picking mask at point: {}".format(point))
            print("No mask found")
            return
        mask_ids = np.where(masks[:, point[1], point[0]])[0]
        print("Picking mask at point: {}".format(point))
        if not len(mask_ids):
            print("No mask found")
            mask_id = -1
            local_mask = np.zeros((masks.shape[1], masks.shape[2]))
            label = None
        else:
            mask_id = self.mask_enum.pick(mask_ids)
            local_mask = annotator.masks.get_mask(mask_id)
            # NOTE(review): labels are fetched with a +1 offset relative to get_mask;
            # presumably the label store is 1-indexed — confirm in the masks class.
            label = annotator.masks.get_label(mask_id + 1)
        annotator.masks.mask_id = mask_id
        annotator.last_mask = local_mask
        annotator.visualize_last_mask(label)

    def _pick_bounding_box(self, ev):
        """Select the stored bounding box closest to the clicked position."""
        size = self.size()
        # The click is passed as a fraction of the widget's width/height.
        point = [
            float(ev.pos().x() / size.width()),
            float(ev.pos().y() / size.height())]
        bounding_box, bounding_box_id = self.parent().annotator.bounding_boxes.find_closest_bounding_box(point)
        if bounding_box is None:
            print("No bounding box found")
        else:
            self.parent().annotator.bounding_boxes.bounding_box_id = bounding_box_id
            print(f"Bounding box: {bounding_box}")
            print(f"Bounding box id: {bounding_box_id}")
        self.parent().update(self.parent().annotator.merge_image_visualization())

    def zoom_to_rectangle(self, xstart, ystart, xend, yend):
        """Crop the annotator image to the rectangle and resize it to config.window_size."""
        picked_image = self.parent().annotator.image[ystart:yend, xstart:xend, :]
        self.parent().annotator.image = cv2.resize(picked_image, (self.config.window_size[0], self.config.window_size[1]))
        self.update()

    def keyPressEvent(self, ev: PySide6.QtGui.QKeyEvent) -> None:
        """In MASK_PICKER mode, "D" deletes the currently selected mask."""
        print(ev.key())
        if self._paint_type == PaintType.MASK_PICKER and ev.key() == QtCore.Qt.Key.Key_D and len(self.parent().annotator.masks):
            masks = self.parent().annotator.masks
            # -1 is the "nothing selected" sentinel (set after an empty pick and
            # after a deletion). Popping it would silently delete the last mask.
            if masks.mask_id is None or masks.mask_id < 0:
                print("No mask selected")
                return
            print("Deleting mask")
            masks.pop(masks.mask_id)
            masks.mask_id = -1
            self.parent().annotator.last_mask = None
            self.parent().update(self.parent().annotator.merge_image_visualization())

    def _get_scale(self):
        """Return (sx, sy) factors mapping widget pixels to config.window_size pixels."""
        return self.config.window_size[0] / self.size().width(), self.config.window_size[1] / self.size().height()

    def get_annotations(
            self,
            temporary_point_positive: PySide6.QtCore.QPoint | None = None,
            temporary_point_negative: PySide6.QtCore.QPoint | None = None
    ):
        """Collect the current prompts, scaled to config.window_size coordinates.

        Args:
            temporary_point_positive: Optional extra positive point (widget coords).
            temporary_point_negative: Optional extra negative point (widget coords).

        Returns:
            dict with ``"points"``, an (N, 2) array with positive points first;
            ``"labels"``, 1 for each positive point and 0 for each negative one;
            and ``"bounding_boxes"``, the scaled box as numpy, or None.
        """
        sx, sy = self._get_scale()
        positive_points = [(p.x() * sx, p.y() * sy) for p in self.positive_points]
        negative_points = [(p.x() * sx, p.y() * sy) for p in self.negative_points]
        # Use an explicit None check: a QPoint at (0, 0) can evaluate as falsy.
        if temporary_point_positive is not None:
            positive_points.append((temporary_point_positive.x() * sx, temporary_point_positive.y() * sy))
        if temporary_point_negative is not None:
            negative_points.append((temporary_point_negative.x() * sx, temporary_point_negative.y() * sy))
        positive_points = np.array(positive_points).reshape(-1, 2)
        negative_points = np.array(negative_points).reshape(-1, 2)
        labels = np.array([1, ] * len(positive_points) + [0, ] * len(negative_points))
        print(f"Positive points: {positive_points}")
        print(f"Negative points: {negative_points}")
        print(f"Labels: {labels}")
        return {
            "points": np.concatenate([positive_points, negative_points], axis=0),
            "labels": labels,
            "bounding_boxes": self.bounding_box.scale(sx, sy).to_numpy() if self.bounding_box else None
        }

    def clear(self):
        """Drop all prompts (points, boxes, polygon) and repaint."""
        self.positive_points = []
        self.negative_points = []
        self.bounding_box = None
        self.partial_box = BoundingBox(0, 0, 0, 0)
        self.polygon = Polygon()
        self.update()