commit b436a81091e581da4775f1b2acfd1d4388907df2 Author: AI-team\cyhan Date: Mon May 12 11:23:49 2025 +0900 initial_tune diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..12c6135 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +*.pyc +data/ +checkpoints/ +__pycache__/ +*.pt +*.pth +.venv/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b6864ff --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Branislav Hesko + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c2a5295 --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +# Segment Anything UI +Simple UI for the model: [Segment anything](https://github.com/facebookresearch/segment-anything) from Facebook. + + +Segment anything UI for annotations +![GUI](./assets/example.png) + + + +# Usage + + 1. Install segment-anything python package from Github: [Segment anything](https://github.com/facebookresearch/segment-anything). Usually it is enough to run: ```pip install git+https://github.com/facebookresearch/segment-anything.git```. + 2. Download checkpoint [Checkpoint_Huge](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) or [Checkpoint_Large](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) or [Checkpoint_Base](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and put it into workspace folder. + 3. Fill default_path in ```segment_anything_ui/config.py```. + 4. (Optional) Install efficientnet models ```pip install git+https://github.com/mit-han-lab/efficientvit```. See note below for Windows users about installing onnx! + 5. (Optional) Install sam_hq models ```pip install segment-anything-hq``` + 5. Install requirements.txt. ```pip install -r requirements.txt```. + 6. If on Ubuntu or Debian based distro, please use the following ```apt install libxkbcommon-x11-0 qt5dxcb-plugin libxcb-cursor0```. This will fix issues with Qt. + 7. ```export PYTHONPATH=$PYTHONPATH:.```. + 8. ```python segment_anything_ui/main_window.py```. + +Currently, for saving a simple format is used: mask is saved as .png file, when masks are represented by values: 1 ... n and corresponding labels are saved as jsons. In json, labels are a map with mapping: MASK_ID: LABEL. MASK_ID is the id of the stored mask and LABEL is one of "labels.json" files. + + +``` +For windows users, sometimes you will observe onnx used in EfficientVit is not easy to install using pip. In that case, it may be caused by +https://stackoverflow.com/questions/72352528/how-to-fix-winerror-206-the-filename-or-extension-is-too-long-error/76452218#76452218 + +To fix this error on your Windows machine on regedit and navigate to +Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem and edit LongPathsEnabled and set value from 0 to 1 + +Finally, fix onnx version, newest version seems to be broken +pip install onnx==1.15.0 +``` + + +# Checkpoints +Checkpoints are downloaded automatically if the model is not found in the workspace folder. + +### SAM +- `vit_b`: [ViT-B SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) +- `vit_h`: [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) +- `vit_l`: [ViT-L SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) + +### HQ-SAM +- `vit_b`: [ViT-B HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_b.pth) +- `vit_l`: [ViT-L HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_l.pth) +- `vit_h`: [ViT-H HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_h.pth) +- `vit_tiny` (**Light HQ-SAM** for real-time need): [ViT-Tiny HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_tiny.pth) + +### EfficientViT +- `xl0`: [EfficientViT-XL0 model](https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt) +- `xl1`: [EfficientViT-XL1 model](https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt) +- `l2`: [EfficientViT-L2 model](https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt) + +### SAM2 +- `sam2.1_hiera_t`: [SAM2.1 Hiera Tiny model](https://dl.fbaipublicfiles.com/segment_anything_2/092824//sam2.1_hiera_tiny.pt) +- `sam2.1_hiera_l`: [SAM2.1 Hiera Small model](https://dl.fbaipublicfiles.com/segment_anything_2/092824//sam2.1_hiera_small.pt) +- `sam2.1_hiera_b+`: [SAM2.1 Hiera Base+ model](https://dl.fbaipublicfiles.com/segment_anything_2/092824//sam2.1_hiera_base_plus.pt) +- `sam2.1_hiera_s`: [SAM2.1 Hiera Large model](https://dl.fbaipublicfiles.com/segment_anything_2/092824//sam2.1_hiera_large.pt) + +# Functions +There are multiple functions that this UI implements. Few of them are: + + * Add points - by left click of mouse button, you can add a positive point. By right click, you can add a negative point. + * Add boxes - a bounding box can be added to SAM annotation when the proper annotation tool is selected. Manual points, boxes and polygons in the future are used for SAM prediction. + * Add manual polygons - by clicking in the image using left mouse button (and selected manual annotation) manual annotation is done. It does not provide any other features right now. + * Instance mask is saved by clicking on "Save Annotation" button. + * Turn on bounding boxes - by clicking on "Turn on bounding boxes" button, you can turn on bounding boxes for the image. + * Save only bounding boxes - by clicking on "Save only bounding boxes" button, you can save only bounding boxes for the image. + * Annotate All - uses SAM to predict all masks by prompting the model with a grid of points followed by post-processing to refine masks. + * Pick Mask - Select a Mask from the image to delete it or inspect. Modifications are currently not allowed. As annotator allows multiple instances assigned to a pixel, Left clicking on the pixel cyclically changes between assigned masks. + * Each Instance mask is assigned a class according to the chosen label in the list. This list is loaded from labels.json file. + * Masks are inherently ordered based on the time of their creation, with earlier masks being dominant. Therefore masks that were annotated sooner are always present in the final mask, which is in form of an image with values 0 - N for N object instances. In the saved annotation, a pixel has only one value. A mask that is supposed to be used in the final annotation can be moved to the front by picking the desired mask and clicking on "Move Current Mask to Front" button. This is especially useful for "Annotate All" function. + * As each pixel can have multiple masks assigned, a lot of masks may be hidden. In this case only a few pixel could afterwards be present in the final annotation (mostly on the border) and therefore "Remove Hidden Masks" button can be used. This button removes all masks with hidden pixels (they are added later than the dominant mask) with a visible IOU less threshold. + * Masks annotations that have been "picked" can be deleted by Cancel current annotation button. + * Cancel annotation - Cancel current annotation - all points, bounding boxes etc... + * Partial mask - Some objects are hard to be automatically annotated - partial mask allows annotating a single instance by parts: each time a partial instance mask is annotated by clicking on the corresponding button the partial mask is enlarged and merged. Finally the final instance mask is given as a union of all partial masks. + * Zoom - With zoom tool a user can zoom onto some part of the image, annotate it and then this annotation is propagated to the whole image. + * Tags - Each image can have a tag assigned to it. This tag is saved in the annotation file and can be used for filtering images. Use comma separated values for multiple tags. + * When holding MOUSE key, mask proposal is shown. This is useful for eliminating duplicite clicking. + +# Buttons + +Please Note that buttons are fully configurable. You can change bindings and add new buttons in **segment_anything_ui/config.py**. + +| **Button** | **Description** | **Shortcut** | +| --- | --- | --- | +| Add Points | Mouse click will add positive (left) or negative (right) point. | W | +| Add Box | Mouse click will add a box. | Q | +| Annotate All | Runs regular grid annotation with parameters from the form | Enter | +| Pick Mask | Pick mask when clicking on it. Cycling through masks if pixel belongs to multiple masks. | X | +| Merge Masks | WIP: Merge masks when clicking on them. Cycling through masks if pixel belongs to multiple masks. | Z | +| Move Current Mask to Front | Use current mask as background (less important) | None | +| Cancel Annotation | Cancel current annotation | C | +| Save Annotation | Save current annotation | S | +| Manual Polygon | Draw polygon with mouse | R | +| Partial Mask | Allows to split a single object into multiple prompts. When pressed, partial mask is stored and summed with the following promped masks. | D | +| Remove Hidden Masks | Remove masks that have hidden pixels. | None | +| Zoom to Rectangle | Zoom in on the image. | E | +| ---- | ---- | ---- | +| Open Files | Load files from the folder | None | +| Next File | Load next image | F | +| Previous File | Load previous image | G | +| ---- | ---- | ---- | +| Precompute All Embeddings | Currently not implemented | None | +| Show Image | Currently not implemented | None | +| Show Visualization | Currently not implemented | None | + + +# TODO: + + - [x] - FIX: mouse picker for small objects is not precise. + - [ ] - Region merging. + - [x] - Manual annotation. + - [x] - Saving and loading of masks. + - [x] - Class support for assigning classes to objects. + - [x] - Add object borders. + - [x] - Fix mask size and QLabel size for precise mouse clicks. + - [ ] - Draft mask when no points are visible. + - [x] - Box zoom support. \ No newline at end of file diff --git a/assets/example.png b/assets/example.png new file mode 100644 index 0000000..5b67775 Binary files /dev/null and b/assets/example.png differ diff --git a/install_script b/install_script new file mode 100644 index 0000000..3cbbb8c --- /dev/null +++ b/install_script @@ -0,0 +1,6 @@ +uv venv +./venv/bin/activate +uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 +uv pip install git+https://github.com/facebookresearch/segment-anything.git +uv pip install git+https://github.com/facebookresearch/sam2.git +uv pip install -r requirements.txt \ No newline at end of file diff --git a/labels.json b/labels.json new file mode 100644 index 0000000..77b1707 --- /dev/null +++ b/labels.json @@ -0,0 +1,4 @@ +{ + "PRODUCT": 1, + "SEPARATOR": 2 +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9a22efe --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +pyside6 +opencv-python +torch +torchvision +numpy +segment-anything +matplotlib +scikit-image +tqdm +wget +requests \ No newline at end of file diff --git a/run_ui.bat b/run_ui.bat new file mode 100644 index 0000000..000dde3 --- /dev/null +++ b/run_ui.bat @@ -0,0 +1,2 @@ +set PYTHONPATH=%PYTHONPATH%;%~dp0 +python segment_anything_ui/main_window.py diff --git a/run_ui.sh b/run_ui.sh new file mode 100644 index 0000000..9bbdd32 --- /dev/null +++ b/run_ui.sh @@ -0,0 +1,2 @@ +export PYTHONPATH=$PYTHONPATH:. +python3 segment_anything_ui/main_window.py \ No newline at end of file diff --git a/segment_anything_ui/__init__.py b/segment_anything_ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/segment_anything_ui/annotation_layout.py b/segment_anything_ui/annotation_layout.py new file mode 100644 index 0000000..314ec2a --- /dev/null +++ b/segment_anything_ui/annotation_layout.py @@ -0,0 +1,218 @@ +import enum +import json +import os +import numpy as np +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox + +from segment_anything_ui.draw_label import PaintType +from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation + + +class MergeState(enum.Enum): + PICKING = enum.auto() + MERGING = enum.auto() + + +class AnnotationLayout(QWidget): + + def __init__(self, parent, config) -> None: + super().__init__(parent) + self.config = config + self.zoom_flag = False + self.merge_state = MergeState.PICKING + self.layout = QVBoxLayout(self) + labels = self._load_labels(config) + self.layout.setAlignment(Qt.AlignTop) + self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]") + self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]") + self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]") + self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]") + self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]") + self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]") + self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]") + self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]") + self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]") + self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]") + self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]") + self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]") + self.label_picker = QListWidget() + self.label_picker.addItems(labels) + self.label_picker.setCurrentRow(0) + self.move_current_mask_background = QPushButton("Move Current Mask to Front") + self.remove_hidden_masks = QPushButton("Remove Hidden Masks") + self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than") + self.remove_hidden_masks_line = QLineEdit("0.5") + self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key) + self.add_point.setShortcut(config.key_mapping.ADD_POINT.key) + self.add_box.setShortcut(config.key_mapping.ADD_BOX.key) + self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key) + self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key) + self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key) + self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key) + self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key) + self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key) + self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key) + + self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings()) + for w in [ + self.add_point, + self.add_box, + self.annotate_all, + self.pick_mask, + self.pick_bounding_box, + self.merge_masks, + self.move_current_mask_background, + self.cancel_annotation, + self.delete_mask, + self.partial_annotation, + self.save_annotation, + self.manual_polygon, + self.label_picker, + self.zoom_rectangle, + self.annotation_settings, + self.remove_hidden_masks, + self.remove_hidden_masks_label, + self.remove_hidden_masks_line + ]: + self.layout.addWidget(w) + self.add_point.clicked.connect(self.on_add_point) + self.add_box.clicked.connect(self.on_add_box) + self.annotate_all.clicked.connect(self.on_annotate_all) + self.cancel_annotation.clicked.connect(self.on_cancel_annotation) + self.save_annotation.clicked.connect(self.on_save_annotation) + self.pick_mask.clicked.connect(self.on_pick_mask) + self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box) + self.manual_polygon.clicked.connect(self.on_manual_polygon) + self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks) + self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn) + self.merge_masks.clicked.connect(self.on_merge_masks) + self.partial_annotation.clicked.connect(self.on_partial_annotation) + self.delete_mask.clicked.connect(self.on_delete_mask) + self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle) + + def on_delete_mask(self): + if self.parent().image_label.paint_type == PaintType.MASK_PICKER: + self.parent().info_label.setText("Deleting mask!") + mask_uid = self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id) + self.parent().annotator.bounding_boxes.remove(mask_uid) + self.parent().annotator.masks.mask_id = -1 + self.parent().annotator.last_mask = None + self.parent().update(self.parent().annotator.merge_image_visualization()) + elif self.parent().image_label.paint_type == PaintType.BOX_PICKER: + self.parent().info_label.setText("Deleting bounding box!") + mask_uid = self.parent().annotator.bounding_boxes.remove_by_id( + self.parent().annotator.bounding_boxes.bounding_box_id) + if mask_uid is not None: + self.parent().annotator.masks.pop_by_uuid(mask_uid) + self.parent().annotator.bounding_boxes.bounding_box_id = -1 + self.parent().annotator.last_mask = None + self.parent().annotator.masks.mask_id = -1 + self.parent().update(self.parent().annotator.merge_image_visualization()) + else: + QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!") + + def on_partial_annotation(self): + self.parent().info_label.setText("Partial annotation!") + self.parent().annotator.pick_partial_mask() + self.parent().image_label.clear() + + @staticmethod + def _load_labels(config): + if not os.path.exists(config.label_file): + return ["default"] + with open(config.label_file, "r") as f: + labels = json.load(f) + MasksAnnotation.DEFAULT_LABEL = list(labels.keys())[0] if len(labels) > 0 else "default" + return labels + + def on_merge_masks(self): + self.parent().image_label.change_paint_type(PaintType.MASK_PICKER) + if self.merge_state == MergeState.PICKING: + self.parent().info_label.setText("Pick a mask to merge with!") + self.merge_state = MergeState.MERGING + self.parent().annotator.merged_mask = self.parent().annotator.last_mask.copy() + elif self.merge_state == MergeState.MERGING: + self.parent().info_label.setText("Merging masks!") + self.parent().annotator.merge_masks() + self.merge_state = MergeState.PICKING + + def on_move_current_mask_background_fn(self): + self.parent().info_label.setText("Moving current mask to background!") + self.parent().annotator.move_current_mask_to_background() + self.parent().update(self.parent().annotator.merge_image_visualization()) + + def on_remove_hidden_masks(self): + self.parent().info_label.setText("Removing hidden masks!") + annotations = self.parent().annotator.masks + argmax_mask = self.parent().annotator.make_instance_mask() + limit_ratio = float(self.remove_hidden_masks_line.text()) + new_masks = [] + new_labels = [] + for idx, (mask, label) in enumerate(annotations): + num_pixels = np.sum(mask > 0) + num_visible = np.sum(argmax_mask == (idx + 1)) + ratio = num_visible / num_pixels + + if ratio > limit_ratio: + new_masks.append(mask) + new_labels.append(label) + + print("Removed ", len(annotations) - len(new_masks), " masks.") + self.parent().annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels) + self.parent().update(self.parent().annotator.merge_image_visualization()) + + def on_pick_mask(self): + self.parent().info_label.setText("Pick a mask to do modifications!") + self.parent().image_label.change_paint_type(PaintType.MASK_PICKER) + + def on_pick_bounding_box(self): + self.parent().info_label.setText("Pick a bounding box to do modifications!") + self.parent().image_label.change_paint_type(PaintType.BOX_PICKER) + + def on_manual_polygon(self): + # Sets emphasis on the button + self.manual_polygon.setProperty("active", True) + self.parent().image_label.change_paint_type(PaintType.POLYGON) + + def on_add_point(self): + self.parent().info_label.setText("Adding point annotation!") + self.parent().image_label.change_paint_type(PaintType.POINT) + + def on_add_box(self): + self.parent().info_label.setText("Adding box annotation!") + self.parent().image_label.change_paint_type(PaintType.BOX) + + def on_zoom_rectangle(self): + if self.zoom_flag: + self.parent().info_label.setText("Zooming rectangle OFF!") + self.parent().image_label.change_paint_type(PaintType.POINT) + self.parent().annotator.zoomed_bounding_box = None + self.parent().annotator.make_embedding() + self.parent().update(self.parent().annotator.merge_image_visualization()) + self.zoom_flag = False + else: + self.parent().info_label.setText("Pick Mask to zoom!") + self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]") + self.parent().image_label.change_paint_type(PaintType.ZOOM_PICKER) + self.zoom_flag = True + + def on_annotate_all(self): + self.parent().info_label.setText("Annotating all. This make take some time.") + self.parent().annotator.predict_all(self.annotation_settings.get_values()) + self.parent().update(self.parent().annotator.merge_image_visualization()) + self.parent().info_label.setText("Annotate all finished.") + + def on_cancel_annotation(self): + self.parent().image_label.clear() + self.parent().annotator.clear_last_masks() + self.parent().update(self.parent().annotator.merge_image_visualization()) + + def on_save_annotation(self): + if self.parent().image_label.paint_type == PaintType.POLYGON: + self.parent().annotator.last_mask = self.parent().image_label.polygon.to_mask( + self.config.window_size[0], self.config.window_size[1]) + self.parent().annotator.save_mask(label=self.label_picker.currentItem().text()) + self.parent().update(self.parent().annotator.merge_image_visualization()) + self.parent().image_label.clear() + diff --git a/segment_anything_ui/annotator.py b/segment_anything_ui/annotator.py new file mode 100644 index 0000000..05fc74d --- /dev/null +++ b/segment_anything_ui/annotator.py @@ -0,0 +1,445 @@ +import dataclasses +from typing import Callable +import uuid + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit +from segment_anything import SamPredictor +from segment_anything.build_sam import Sam +from segment_anything_ui.model_builder import ( + get_predictor, get_mask_generator, SamPredictor) +try: + from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam +except (ImportError, ModuleNotFoundError): + class EfficientViTSamPredictor: + pass + + class EfficientViTSam: + pass + +from skimage.measure import regionprops +import torch + +from segment_anything_ui.utils.shapes import BoundingBox +from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box + +def get_cmap(n, name='hsv'): + '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct + RGB color; the keyword argument name must be a standard mpl colormap name.''' + try: + return plt.cm.get_cmap(name, n) + except: + return plt.get_cmap(name, n) + +def crop_image( + image, + box: BoundingBox | None = None, + image_shape: tuple[int, int] | None = None +): + if image_shape is None: + image_shape = image.shape[:2][::-1] + if box is None: + return cv2.resize(image, image_shape) + + if len(image.shape) == 2: + return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend], image_shape) + return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend, :], image_shape) + + +def insert_image(image, box: BoundingBox | None = None): + new_image = np.zeros_like(image) + if box is None: + new_image = image + else: + if len(image.shape) == 2: + new_image[box.ystart:box.yend, box.xstart:box.xend] = cv2.resize( + image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart))) + else: + new_image[box.ystart:box.yend, box.xstart:box.xend, :] = cv2.resize( + image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart))) + return new_image + + +@dataclasses.dataclass() +class AutomaticMaskGeneratorSettings: + points_per_side: int = 32 + pred_iou_thresh: float = 0.88 + stability_score_thresh: float = 0.95 + stability_score_offset: float = 1.0 + box_nms_thresh: float = 0.7 + crop_n_layers: int = 0 + crop_nms_thresh: float = 0.7 + + +class LabelValueParam(QWidget): + def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None): + super().__init__(parent) + self.layout = QVBoxLayout(self) + self.layout.setSpacing(0) + self.layout.setContentsMargins(0, 0, 0, 0) + self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter) + self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter) + self.layout.addWidget(self.label) + self.layout.addWidget(self.value) + self.converter = value_type_converter + + def get_value(self): + return self.converter(self.value.text()) + + +class CustomForm(QWidget): + + def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None: + super().__init__(parent) + self.layout = QVBoxLayout(self) + self.layout.setSpacing(0) + self.layout.setContentsMargins(0, 0, 0, 0) + self.widgets = [] + + for field in dataclasses.fields(automatic_mask_generator_settings): + widget = LabelValueParam(field.name, str(field.default), field.type) + self.widgets.append(widget) + self.layout.addWidget(widget) + + def get_values(self): + return AutomaticMaskGeneratorSettings(**{widget.label.text(): widget.get_value() for widget in self.widgets}) + + +class BoundingBoxAnnotation: + def __init__(self) -> None: + self.bounding_boxes: list[BoundingBox] = [] + self.bounding_box_id: int = -1 + + def append(self, bounding_box: BoundingBox): + self.bounding_boxes.append(bounding_box) + + def find_closest_bounding_box(self, point: np.ndarray): + closest_bounding_box = None + closest_bounding_box_id = -1 + min_distance = float('inf') + for idx, bounding_box in enumerate(self.bounding_boxes): + distance = bounding_box.distance_to(point) + if distance < min_distance and bounding_box.contains(point): + min_distance = distance + closest_bounding_box = bounding_box + closest_bounding_box_id = idx + self.bounding_box_id = closest_bounding_box_id + return closest_bounding_box, closest_bounding_box_id + + def get_bounding_box(self, bounding_box_id: int): + return self.bounding_boxes[bounding_box_id] + + def get_current_bounding_box(self): + return self.bounding_boxes[self.bounding_box_id] + + def set_current_bounding_box(self, bounding_box: BoundingBox): + self.bounding_boxes[self.bounding_box_id] = bounding_box + + def remove(self, mask_uid: str): + bounding_box_id = next((idx for idx, bounding_box in enumerate(self.bounding_boxes) if bounding_box.mask_uid == mask_uid), None) + if bounding_box_id is None: + return + bounding_box = self.bounding_boxes.pop(bounding_box_id) + if self.bounding_box_id >= bounding_box_id: + self.bounding_box_id -= 1 + return bounding_box + + def remove_by_id(self, bounding_box_id: int): + mask_uid = self.bounding_boxes[bounding_box_id].mask_uid + self.remove(mask_uid) + return mask_uid + + def __len__(self): + return len(self.bounding_boxes) + + +class MasksAnnotation: + DEFAULT_LABEL = "default" + + def __init__(self) -> None: + self.masks = [] + self.label_map = {} + self.masks_uids: list[str] = [] + self.mask_id: int = -1 + + def add_mask(self, mask, label: str | None = None): + self.masks.append(mask) + self.masks_uids.append(str(uuid.uuid4())) + self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label + return self.masks_uids[-1] + + def add_label(self, mask_id: int, label: str): + self.label_map[mask_id] = label + + def get_mask(self, mask_id: int): + return self.masks[mask_id] + + def get_label(self, mask_id: int): + return self.label_map[mask_id] + + def get_current_mask(self): + return self.masks[self.mask_id] + + def set_current_mask(self, mask, label: str = None): + self.masks[self.mask_id] = mask + self.label_map[self.mask_id] = self.DEFAULT_LABEL if label is None else label + + def __getitem__(self, mask_id: int): + return self.get_mask(mask_id) + + def __setitem__(self, mask_id: int, value): + self.masks[mask_id] = value + + def __len__(self): + return len(self.masks) + + def __iter__(self): + return iter(zip(self.masks, self.label_map.values())) + + def __next__(self): + if self.mask_id >= len(self.masks): + raise StopIteration + return self.masks[self.mask_id] + + def append(self, mask, label: str | None = None): + return self.add_mask(mask, label) + + def pop_by_uuid(self, mask_uid: str): + mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None) + if mask_id is None: + return + return self.pop(mask_id) + + def pop(self, mask_id: int = -1): + _ = self.masks.pop(mask_id) + mask_uid = self.masks_uids.pop(mask_id) + self.label_map.pop(mask_id + 1) + new_label_map = {} + for index, value in enumerate(self.label_map.values()): + new_label_map[index + 1] = value + self.label_map = new_label_map + return mask_uid + + @classmethod + def from_masks(cls, masks, labels: list[str] | None = None): + annotation = cls() + if labels is None: + labels = [None] * len(masks) + for mask, label in zip(masks, labels): + annotation.append(mask, label) + return annotation + + +@dataclasses.dataclass() +class Annotator: + sam: Sam | EfficientViTSam | None = None + embedding: torch.Tensor | None = None + image: np.ndarray | None = None + masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation) + bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation) + predictor: SamPredictor | EfficientViTSamPredictor | None = None + visualization: np.ndarray | None = None + last_mask: np.ndarray | None = None + partial_mask: np.ndarray | None = None + merged_mask: np.ndarray | None = None + parent: QWidget | None = None + cmap: plt.cm = None + original_image: np.ndarray | None = None + zoomed_bounding_box: BoundingBox | None = None + + def __post_init__(self): + self.MAX_MASKS = 10 + self.cmap = get_cmap(self.MAX_MASKS) + + def set_image(self, image: np.ndarray): + self.image = image + return self + + def make_embedding(self): + if self.sam is None: + return + self.predictor = get_predictor(self.sam) + self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box)) + + def predict_all(self, settings: AutomaticMaskGeneratorSettings): + generator = get_mask_generator( + sam=self.sam, + **dataclasses.asdict(settings) + ) + masks = generator.generate(self.image) + masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks] + label = self.parent.annotation_layout.label_picker.currentItem().text() + self.masks = MasksAnnotation.from_masks(masks, [label, ] * len(masks)) + self.cmap = get_cmap(len(self.masks)) + + def make_prediction(self, annotation: dict): + masks, scores, logits = self.predictor.predict( + point_coords=annotation["points"], + point_labels=annotation["labels"], + box=annotation["bounding_boxes"], + multimask_output=False + ) + mask = masks[0] + self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255 + + def pick_partial_mask(self): + if self.partial_mask is None: + self.partial_mask = self.last_mask.copy() + else: + self.partial_mask = np.maximum(self.last_mask, self.partial_mask) + self.last_mask = None + + def move_current_mask_to_background(self): + self.masks.set_current_mask(self.masks.get_current_mask() * 0.5) + + def merge_masks(self): + new_mask = np.bitwise_or(self.last_mask, self.merged_mask) + self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text()) + self.merged_mask = None + + def visualize_last_mask(self, label: str | None = None): + last_mask = np.zeros_like(self.image) + last_mask[:, :, 1] = self.last_mask + if self.partial_mask is not None: + last_mask[:, :, 0] = self.partial_mask + if self.merged_mask is not None: + last_mask[:, :, 2] = self.merged_mask + if label is not None: + props = regionprops(self.last_mask)[0] + cv2.putText( + last_mask, + label, + (int(props.centroid[1]), int(props.centroid[0])), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + [255, 255, 255], + 2 + ) + if self.is_show_bounding_boxes: + last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label) + cv2.rectangle( + last_mask, + (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])), + (int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])), + (0, 255, 0), + 2 + ) + cv2.putText( + last_mask, + label, + (int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + [255, 255, 255], + 2 + ) + visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(), + 0.8, last_mask, 0.5, 0) + self.parent.update(crop_image(visualization, self.zoomed_bounding_box)) + + def _visualize_mask(self) -> tuple: + mask_argmax = self.make_instance_mask() + visualization = np.zeros_like(self.image) + border = np.zeros(self.image.shape[:2], dtype=np.uint8) + for i in range(1, np.amax(mask_argmax) + 1): + color = self.cmap(i) + single_mask = np.zeros_like(mask_argmax) + single_mask[mask_argmax == i] = 1 + visualization[mask_argmax == i, :] = np.array(color[:3]) * 255 + border += single_mask - cv2.erode( + single_mask, np.ones((3, 3), np.uint8), iterations=1) + label = self.masks.get_label(i) + single_mask_center = np.mean(np.where(single_mask == 1), axis=1) + if np.isnan(single_mask_center).any(): + continue + if self.parent.settings.is_show_text(): + cv2.putText( + visualization, + label, + (int(single_mask_center[1]), int(single_mask_center[0])), + cv2.FONT_HERSHEY_PLAIN, + 0.5, + [255, 255, 255], + 1 + ) + if self.is_show_bounding_boxes: + bounding_boxes = self.get_bounding_boxes() + for idx, bounding_box in enumerate(bounding_boxes): + cv2.rectangle( + visualization, + (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])), + (int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])), + (0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0), + 2 + ) + cv2.putText( + visualization, + bounding_box.label, + (int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + [255, 255, 255], + 2 + ) + border = (border == 0).astype(np.uint8) + return visualization, border + + def has_annotations(self): + return len(self.masks) > 0 + + def make_instance_mask(self): + background = np.zeros_like(self.masks[0]) + 1 + mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8) + return mask_argmax + + def get_bounding_boxes(self): + return get_bounding_boxes(self.masks.masks, self.masks.label_map.values()) + + def merge_image_visualization(self): + image = self.image.copy() + if not len(self.masks): + return crop_image(image, self.zoomed_bounding_box) + visualization, border = self._visualize_mask() + self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis] + return crop_image(self.visualization, self.zoomed_bounding_box) + + def remove_last_mask(self): + mask_id = len(self.masks) + self.masks.pop(mask_id) + self.bounding_boxes.remove(mask_id) + + def make_labels(self): + return self.masks.label_map + + def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL): + if self.partial_mask is not None: + last_mask = self.partial_mask + self.partial_mask = None + else: + last_mask = self.last_mask + mask_uid = self.masks.add_mask(last_mask, label=label) + corresponding_bounding_box = get_mask_bounding_box(last_mask, label) + corresponding_bounding_box.mask_uid = mask_uid + self.bounding_boxes.append(corresponding_bounding_box) + if len(self.masks) >= self.MAX_MASKS: + self.MAX_MASKS += 10 + self.cmap = get_cmap(self.MAX_MASKS) + + @property + def is_show_bounding_boxes(self): + return self.parent.settings.is_show_bounding_boxes() + + def clear_last_masks(self): + self.last_mask = None + self.partial_mask = None + self.visualization = None + + def clear(self): + self.last_mask = None + self.visualization = None + self.masks = MasksAnnotation() + self.bounding_boxes = BoundingBoxAnnotation() + self.partial_mask = None diff --git a/segment_anything_ui/config.py b/segment_anything_ui/config.py new file mode 100644 index 0000000..9559e8b --- /dev/null +++ b/segment_anything_ui/config.py @@ -0,0 +1,143 @@ +import dataclasses +import os +from typing import Literal + +from PySide6.QtCore import Qt +import requests +try: + from tqdm import tqdm + import wget +except ImportError: + print("Tqdm and wget not found. Install with pip install tqdm wget") + tqdm = None + wget = None + + +@dataclasses.dataclass(frozen=True) +class Keymap: + key: Qt.Key | str + name: str + + +class ProgressBar: + def __init__(self): + self.progress_bar = None + + def __call__(self, current_bytes, total_bytes, width): + current_mb = round(current_bytes / 1024 ** 2, 1) + total_mb = round(total_bytes / 1024 ** 2, 1) + if self.progress_bar is None: + self.progress_bar = tqdm(total=total_mb, desc="MB") + delta_mb = current_mb - self.progress_bar.n + self.progress_bar.update(delta_mb) + + +@dataclasses.dataclass +class KeyBindings: + ADD_POINT: Keymap = Keymap(Qt.Key.Key_W, "W") + ADD_BOX: Keymap = Keymap(Qt.Key.Key_Q, "Q") + ANNOTATE_ALL: Keymap = Keymap(Qt.Key.Key_Return, "Enter") + MANUAL_POLYGON: Keymap = Keymap(Qt.Key.Key_R, "R") + CANCEL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_C, "C") + SAVE_ANNOTATION: Keymap = Keymap(Qt.Key.Key_S, "S") + PICK_MASK: Keymap = Keymap(Qt.Key.Key_X, "X") + PICK_BOUNDING_BOX: Keymap = Keymap(Qt.Key.Key_B, "B") + MERGE_MASK: Keymap = Keymap(Qt.Key.Key_Z, "Z") + DELETE_MASK: Keymap = Keymap(Qt.Key.Key_V, "V") + PARTIAL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_D, "D") + SAVE_BOUNDING_BOXES: Keymap = Keymap("Ctrl+B", "Ctrl+B") + NEXT_FILE: Keymap = Keymap(Qt.Key.Key_F, "F") + PREVIOUS_FILE: Keymap = Keymap(Qt.Key.Key_G, "G") + SAVE_MASK: Keymap = Keymap("Ctrl+S", "Ctrl+S") + PRECOMPUTE: Keymap = Keymap(Qt.Key.Key_P, "P") + ZOOM_RECTANGLE: Keymap = Keymap(Qt.Key.Key_E, "E") + + +@dataclasses.dataclass +class Config: + default_weights: Literal[ + "sam_vit_b_01ec64.pth", + "sam_vit_h_4b8939.pth", + "sam_vit_l_0b3195.pth", + "xl0.pt", + "xl1.pt", + "sam_hq_vit_b.pth", + "sam_hq_vit_l.pth", + "sam_hq_vit_h.pth", + "sam_hq_vit_tiny.pth", + "sam2.1_hiera_t.pth", + "sam2.1_hiera_l.pth", + "sam2.1_hiera_b+.pth", + "sam2.1_hiera_s.pth", + ] = "sam_vit_h_4b8939.pth" + download_weights_if_not_available: bool = True + label_file: str = "labels.json" + window_size: tuple[int, int] | int = (1920, 1080) + key_mapping: KeyBindings = dataclasses.field(default_factory=KeyBindings) + weights_paths: dict[str, str] = dataclasses.field(default_factory=lambda: { + "l2": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt", + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "xl0": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt", + "xl1": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt", + "hq_vit_b": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth", + "hq_vit_l": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth", + "hq_vit_h": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", + "hq_vit_tiny": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth", + "sam2.1_hiera_t": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "sam2.1_hiera_s": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "sam2.1_hiera_b+": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "sam2.1_hiera_l": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + }) + + def __post_init__(self): + if isinstance(self.window_size, int): + self.window_size = (self.window_size, self.window_size) + if self.download_weights_if_not_available: + self.download_weights() + + def get_sam_model_name(self): + if "l2" in self.default_weights: + return "l2" + if "sam_vit_b" in self.default_weights: + return "vit_b" + if "sam_vit_h" in self.default_weights: + return "vit_h" + if "sam_vit_l" in self.default_weights: + return "vit_l" + if "xl0" in self.default_weights: + return "xl0" + if "xl1" in self.default_weights: + return "xl1" + if "hq_vit_b" in self.default_weights: + return "hq_vit_b" + if "hq_vit_l" in self.default_weights: + return "hq_vit_l" + if "hq_vit_h" in self.default_weights: + return "hq_vit_h" + if "hq_vit_tiny" in self.default_weights: + return "hq_vit_tiny" + if "sam2.1_hiera_t" in self.default_weights: + return "sam2.1_hiera_t" + if "sam2.1_hiera_l" in self.default_weights: + return "sam2.1_hiera_l" + if "sam2.1_hiera_b+" in self.default_weights: + return "sam2.1_hiera_b+" + if "sam2.1_hiera_s" in self.default_weights: + return "sam2.1_hiera_s" + raise ValueError("Unknown model name") + def download_weights(self): + if not os.path.exists(self.default_weights): + try: + print(f"Downloading weights for model {self.get_sam_model_name()}") + wget.download(self.weights_paths[self.get_sam_model_name()], self.default_weights, bar=ProgressBar()) + print(f"Downloaded weights to {self.default_weights}") + except Exception as e: + print(f"Error downloading weights: {e}. Trying with requests.") + model_name = self.get_sam_model_name() + print(f"Downloading weights for model {model_name}") + file = requests.get(self.weights_paths[model_name]) + with open(self.default_weights, "wb") as f: + f.write(file.content) + print(f"Downloaded weights to {self.default_weights}") diff --git a/segment_anything_ui/draw_label.py b/segment_anything_ui/draw_label.py new file mode 100644 index 0000000..353f2eb --- /dev/null +++ b/segment_anything_ui/draw_label.py @@ -0,0 +1,268 @@ +import copy +from enum import Enum + +import cv2 +import numpy as np +import PySide6 +from PySide6 import QtCore, QtWidgets +from PySide6.QtGui import QPainter, QPen + +from segment_anything_ui.config import Config +from segment_anything_ui.utils.shapes import BoundingBox, Polygon + + +class PaintType(Enum): + POINT = 0 + BOX = 1 + MASK = 2 + POLYGON = 3 + MASK_PICKER = 4 + ZOOM_PICKER = 5 + BOX_PICKER = 6 + + +class MaskIdPicker: + + def __init__(self, length) -> None: + self.counter = 0 + self.length = length + + def increment(self): + self.counter = (self.counter + 1) % self.length + + def pick(self, ids): + print("Length of ids: ", len(ids), " counter: ", self.counter, " ids: ", ids) + if len(ids) <= self.counter: + self.counter = 0 + return_id = ids[self.counter] + self.increment() + return return_id + + +class DrawLabel(QtWidgets.QLabel): + + def __init__(self, parent=None): + super().__init__(parent) + self.positive_points = [] + self.negative_points = [] + self.bounding_box = None + self.partial_box = BoundingBox(0, 0) + self._paint_type = PaintType.POINT + self.polygon = Polygon() + self.mask_enum: MaskIdPicker = MaskIdPicker(3) + self.config = Config() + self.setFocusPolicy(QtCore.Qt.StrongFocus) + self._zoom_center = (0, 0) + self._zoom_factor = 1.0 + self._zoom_bounding_box: BoundingBox | None = None + + def paintEvent(self, paint_event): + painter = QPainter(self) + painter.drawPixmap(self.rect(), self.pixmap()) + pen_positive = self._get_pen(QtCore.Qt.green, 5) + pen_negative = self._get_pen(QtCore.Qt.red, 5) + pen_partial = self._get_pen(QtCore.Qt.yellow, 1) + pen_box = self._get_pen(QtCore.Qt.green, 1) + painter.setRenderHint(QPainter.Antialiasing, False) + + painter.setPen(pen_box) + + if self.bounding_box is not None and self.bounding_box.xend != -1 and self.bounding_box.yend != -1: + painter.drawRect( + self.bounding_box.xstart, + self.bounding_box.ystart, + self.bounding_box.xend - self.bounding_box.xstart, + self.bounding_box.yend - self.bounding_box.ystart + ) + + painter.setPen(pen_partial) + painter.drawRect(self.partial_box.xstart, self.partial_box.ystart, + self.partial_box.xend - self.partial_box.xstart, + self.partial_box.yend - self.partial_box.ystart) + + painter.setPen(pen_positive) + for pos in self.positive_points: + painter.drawPoint(pos) + + painter.setPen(pen_negative) + painter.setRenderHint(QPainter.Antialiasing, False) + for pos in self.negative_points: + painter.drawPoint(pos) + + if self.polygon.is_plotable(): + painter.setPen(pen_box) + painter.setRenderHint(QPainter.Antialiasing, True) + painter.drawPolygon(self.polygon.to_qpolygon()) + # self.update() + + def _get_pen(self, color=QtCore.Qt.red, width=5): + pen = QPen() + pen.setWidth(width) + pen.setColor(color) + return pen + + @property + def paint_type(self): + return self._paint_type + + def change_paint_type(self, paint_type: PaintType): + print(f"Changing paint type to {paint_type}") + self._paint_type = paint_type + + def mouseMoveEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None: + if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]: + self.partial_box = copy.deepcopy(self.bounding_box) + self.partial_box.xend = ev.pos().x() + self.partial_box.yend = ev.pos().y() + self.update() + + if self._paint_type == PaintType.POINT: + point = ev.pos() + if ev.buttons() == QtCore.Qt.LeftButton: + self._move_update(None, point) + elif ev.buttons() == QtCore.Qt.RightButton: + self._move_update(point, None) + else: + pass + self.update() + + def _move_update(self, temporary_point_negative, temporary_point_positive): + annotations = self.get_annotations(temporary_point_positive, temporary_point_negative) + self.parent().annotator.make_prediction(annotations) + self.parent().annotator.visualize_last_mask() + + def mouseReleaseEvent(self, cursor_event): + if self._paint_type == PaintType.POINT: + if cursor_event.button() == QtCore.Qt.LeftButton: + self.positive_points.append(cursor_event.pos()) + print(self.size()) + elif cursor_event.button() == QtCore.Qt.RightButton: + self.negative_points.append(cursor_event.pos()) + # self.chosen_points.append(self.mapFromGlobal(QtGui.QCursor.pos())) + elif self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]: + if cursor_event.button() == QtCore.Qt.LeftButton: + self.bounding_box.xend = cursor_event.pos().x() + self.bounding_box.yend = cursor_event.pos().y() + self.partial_box = BoundingBox(-1, -1, -1, -1) + + if not self._paint_type == PaintType.MASK_PICKER and \ + not self._paint_type == PaintType.ZOOM_PICKER and \ + not self._paint_type == PaintType.POLYGON and \ + not self._paint_type == PaintType.BOX_PICKER: + self.parent().annotator.make_prediction(self.get_annotations()) + self.parent().annotator.visualize_last_mask() + + if self._paint_type == PaintType.ZOOM_PICKER: + self.parent().annotator.zoomed_bounding_box = self.bounding_box.scale(*self._get_scale()).to_int() + self.bounding_box = None + self.parent().annotator.make_embedding() + self.parent().update(self.parent().annotator.merge_image_visualization()) + self._paint_type = PaintType.POINT + + self.update() + + def mousePressEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None: + if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER] and ev.button() == QtCore.Qt.LeftButton: + self.bounding_box = BoundingBox(xstart=ev.pos().x(), ystart=ev.pos().y()) + + if self._paint_type == PaintType.POLYGON and ev.button() == QtCore.Qt.LeftButton: + self.polygon.points.append([ev.pos().x(), ev.pos().y()]) + + if self._paint_type == PaintType.MASK_PICKER and ev.button() == QtCore.Qt.LeftButton: + size = self.size() + point = [ + int(ev.pos().x() / size.width() * self.config.window_size[0]), + int(ev.pos().y() / size.height() * self.config.window_size[1])] + masks = np.array(self.parent().annotator.masks.masks) + mask_ids = np.where(masks[:, point[1], point[0]])[0] + print("Picking mask at point: {}".format(point)) + if not(len(mask_ids)): + print("No mask found") + mask_id = -1 + local_mask = np.zeros((masks.shape[1], masks.shape[2])) + label = None + else: + mask_id = self.mask_enum.pick(mask_ids) + local_mask = self.parent().annotator.masks.get_mask(mask_id) + label = self.parent().annotator.masks.get_label(mask_id + 1) + self.parent().annotator.masks.mask_id = mask_id + self.parent().annotator.last_mask = local_mask + self.parent().annotator.visualize_last_mask(label) + + if self._paint_type == PaintType.BOX_PICKER and ev.button() == QtCore.Qt.LeftButton: + size = self.size() + point = [ + float(ev.pos().x() / size.width()), + float(ev.pos().y() / size.height())] + bounding_box, bounding_box_id = self.parent().annotator.bounding_boxes.find_closest_bounding_box(point) + if bounding_box is None: + print("No bounding box found") + else: + self.parent().annotator.bounding_boxes.bounding_box_id = bounding_box_id + print(f"Bounding box: {bounding_box}") + print(f"Bounding box id: {bounding_box_id}") + self.parent().update(self.parent().annotator.merge_image_visualization()) + + + if self._paint_type == PaintType.POINT: + point = ev.pos() + if ev.button() == QtCore.Qt.LeftButton: + self._move_update(None, point) + if ev.button() == QtCore.Qt.RightButton: + self._move_update(point, None) + self.update() + + def zoom_to_rectangle(self, xstart, ystart, xend, yend): + picked_image = self.parent().annotator.image[ystart:yend, xstart:xend, :] + self.parent().annotator.image = cv2.resize(picked_image, (self.config.window_size[0], self.config.window_size[1])) + self.update() + + def keyPressEvent(self, ev: PySide6.QtGui.QKeyEvent) -> None: + print(ev.key()) + if self._paint_type == PaintType.MASK_PICKER and ev.key() == QtCore.Qt.Key.Key_D and len(self.parent().annotator.masks): + print("Deleting mask") + self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id) + self.parent().annotator.masks.mask_id = -1 + self.parent().annotator.last_mask = None + self.parent().update(self.parent().annotator.merge_image_visualization()) + + def _get_scale(self): + return self.config.window_size[0] / self.size().width(), self.config.window_size[1] / self.size().height() + + def get_annotations( + self, + temporary_point_positive: PySide6.QtCore.QPoint | None = None, + temporary_point_negative: PySide6.QtCore.QPoint | None = None + ): + sx, sy = self._get_scale() + positive_points = [( + p.x() * sx, + p.y() * sy) for p in self.positive_points] + negative_points = [( + p.x() * sx, + p.y() * sy) for p in self.negative_points] + + if temporary_point_positive: + positive_points += [(temporary_point_positive.x() * sx, temporary_point_positive.y() * sy)] + if temporary_point_negative: + negative_points += [(temporary_point_negative.x() * sx, temporary_point_negative.y() * sy)] + + positive_points = np.array(positive_points).reshape(-1, 2) + negative_points = np.array(negative_points).reshape(-1, 2) + labels = np.array([1, ] * len(positive_points) + [0, ] * len(negative_points)) + print(f"Positive points: {positive_points}") + print(f"Negative points: {negative_points}") + print(f"Labels: {labels}") + return { + "points": np.concatenate([positive_points, negative_points], axis=0), + "labels": labels, + "bounding_boxes": self.bounding_box.scale(sx, sy).to_numpy() if self.bounding_box else None + } + + def clear(self): + self.positive_points = [] + self.negative_points = [] + self.bounding_box = None + self.partial_box = BoundingBox(0, 0, 0, 0) + self.polygon = Polygon() + self.update() diff --git a/segment_anything_ui/image_pixmap.py b/segment_anything_ui/image_pixmap.py new file mode 100644 index 0000000..90bf6f0 --- /dev/null +++ b/segment_anything_ui/image_pixmap.py @@ -0,0 +1,14 @@ +from PySide6.QtGui import QImage, QPixmap, QPainter, QPen +from PySide6.QtCore import Qt + + +class ImagePixmap(QPixmap): + def __init__(self): + super().__init__() + + def set_image(self, image): + if image.dtype == "uint8": + image = image.astype("float32") / 255.0 + image = (image * 255).astype("uint8") + image = QImage(image.data, image.shape[1], image.shape[0], QImage.Format_RGB888) + self.convertFromImage(image) diff --git a/segment_anything_ui/main_window.py b/segment_anything_ui/main_window.py new file mode 100644 index 0000000..715c3a8 --- /dev/null +++ b/segment_anything_ui/main_window.py @@ -0,0 +1,81 @@ +import logging +import sys + +import cv2 +import numpy as np +import torch +from PySide6.QtWidgets import (QApplication, QGridLayout, QLabel, + QMessageBox, QWidget) +from PySide6.QtCore import Qt + +from segment_anything_ui.annotator import Annotator +from segment_anything_ui.annotation_layout import AnnotationLayout +from segment_anything_ui.config import Config +from segment_anything_ui.draw_label import DrawLabel +from segment_anything_ui.image_pixmap import ImagePixmap +from segment_anything_ui.model_builder import build_model +from segment_anything_ui.settings_layout import SettingsLayout + + +class SegmentAnythingUI(QWidget): + + def __init__(self, config) -> None: + super().__init__() + self.config: Config = config + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.setWindowTitle("Segment Anything UI") + self.setWindowState(Qt.WindowState.WindowMaximized) + # self.setGeometry(100, 100, 800, 600) + self.layout = QGridLayout(self) + self.image_label = DrawLabel(self) + self.settings = SettingsLayout(self, config=self.config) + self.info_label = QLabel("Information about running process.") + self.sam = self.init_sam() + self.annotator = Annotator(sam=self.sam, parent=self) + self.annotation_layout = AnnotationLayout(self, config=self.config) + self.layout.addWidget(self.annotation_layout, 0, 0, 1, 1, Qt.AlignCenter) + self.layout.addWidget(self.image_label, 0, 1, 1, 1, Qt.AlignCenter) + self.layout.addWidget(self.settings, 0, 3, 1, 1, Qt.AlignCenter) + self.layout.addWidget(self.info_label, 1, 1, Qt.AlignBottom) + + self.set_image(np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8)) + self.show() + + def set_image(self, image: np.ndarray, clear_annotations: bool = True): + self.annotator.set_image(image).make_embedding() + if clear_annotations: + self.annotator.clear() + self.update(image) + + def update(self, image: np.ndarray): + image = cv2.resize(image, self.config.window_size) + pixmap = ImagePixmap() + pixmap.set_image(image) + print("Updating image") + self.image_label.setPixmap(pixmap) + + def init_sam(self): + try: + checkpoint_path = str(self.settings.checkpoint_path.text()) + sam = build_model(self.config.get_sam_model_name(), checkpoint_path, self.device) + except Exception as e: + logging.getLogger().exception(f"Error loading model: {e}") + QMessageBox.critical(self, "Error", "Could not load model") + return None + return sam + + def get_mask(self): + return self.annotator.make_instance_mask() + + def get_labels(self): + return self.annotator.make_labels() + + def get_bounding_boxes(self): + return self.annotator.get_bounding_boxes() + + +if __name__ == '__main__': + + app = QApplication(sys.argv) + ex = SegmentAnythingUI(Config()) + sys.exit(app.exec()) \ No newline at end of file diff --git a/segment_anything_ui/model_builder.py b/segment_anything_ui/model_builder.py new file mode 100644 index 0000000..19ffd7e --- /dev/null +++ b/segment_anything_ui/model_builder.py @@ -0,0 +1,129 @@ +import os +from PySide6.QtWidgets import QMessageBox + +try: + from efficientvit.sam_model_zoo import create_efficientvit_sam_model, EfficientViTSam + from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, EfficientViTSamAutomaticMaskGenerator + IS_EFFICIENT_VIT_AVAILABLE = True +except (ModuleNotFoundError, ImportError) as e: + import logging + logging.warning("Efficient is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .") + IS_EFFICIENT_VIT_AVAILABLE = False + +try: + from segment_anything_hq import sam_model_registry as sam_hq_model_registry + from segment_anything_hq import SamPredictor as SamPredictorHQ + from segment_anything_hq import automatic_mask_generator as automatic_mask_generator_hq + from segment_anything_hq.build_sam import Sam as SamHQ + IS_SAM_HQ_AVAILABLE = True + _SAM_HQ_MODEL_REGISTRY = { + "hq_vit_b": "vit_b", + "hq_vit_l": "vit_l", + "hq_vit_h": "vit_h", + "hq_vit_tiny": "vit_tiny", + } +except (ModuleNotFoundError, ImportError) as e: + import logging + logging.warning("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .") + IS_SAM_HQ_AVAILABLE = False + _SAM_HQ_MODEL_REGISTRY = {} + +try: + from sam2.build_sam import build_sam2 + from sam2.modeling.sam2_base import SAM2Base + from sam2.sam2_image_predictor import SAM2ImagePredictor + from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + + IS_SAM2_AVAILABLE = True + from hydra.core.global_hydra import GlobalHydra + from hydra import initialize + + # Reset Hydra's global configuration + if GlobalHydra.instance().is_initialized(): + GlobalHydra.instance().clear() + + _SAM2_MODEL_REGISTRY = { + "sam2.1_hiera_t": "sam2.1_hiera_t.yaml", + "sam2.1_hiera_l": "sam2.1_hiera_l.yaml", + "sam2.1_hiera_b": "sam2.1_hiera_b+.yaml", + "sam2.1_hiera_s": "sam2.1_hiera_s.yaml", + } + +except (ModuleNotFoundError, ImportError) as e: + import logging + logging.warning("SAM2 is not available, please install the package from https://github.com/SysCV/sam2 .") + IS_SAM2_AVAILABLE = False + _SAM2_MODEL_REGISTRY = {} + +from segment_anything import sam_model_registry +from segment_anything import SamPredictor, automatic_mask_generator +from segment_anything.build_sam import Sam + + +def build_model(model_name: str, checkpoint_path: str, device: str): + match model_name: + case "xl0" | "xl1": + if not IS_EFFICIENT_VIT_AVAILABLE: + raise ValueError("EfficientViTSam is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .") + efficientvit_sam = create_efficientvit_sam_model( + name=model_name, weight_url=checkpoint_path, + ) + return efficientvit_sam.to(device).eval() + + case "vit_b" | "vit_l" | "vit_h": + sam = sam_model_registry[model_name]( + checkpoint=checkpoint_path) + sam.eval() + return sam.to(device) + + case "hq_vit_b" | "hq_vit_l" | "hq_vit_h": + if not IS_SAM_HQ_AVAILABLE: + QMessageBox.critical(None, "Segment Anything HQ is not available", "Please install the package from http://github.com/SysCV/sam-hq .") + raise ValueError("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .") + sam = sam_hq_model_registry[_SAM_HQ_MODEL_REGISTRY[model_name]]( + checkpoint=checkpoint_path) + sam.eval() + return sam.to(device) + case "sam2.1_hiera_t" | "sam2.1_hiera_l" | "sam2.1_hiera_b" | "sam2.1_hiera_s": + if not IS_SAM2_AVAILABLE: + QMessageBox.critical(None, "SAM2 is not available", "Please install the package from https://github.com/facebookresearch/sam2 .") + raise ValueError("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .") + + with initialize(version_base=None, config_path="sam2_configs"): + sam = build_sam2(_SAM2_MODEL_REGISTRY[model_name], checkpoint_path, device=device) + sam.eval() + return sam + case _: + raise ValueError(f"Model {model_name} not supported") + + +def get_predictor(sam): + if isinstance(sam, Sam): + return SamPredictor(sam) + elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam): + return EfficientViTSamPredictor(sam) + + elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ): + return SamPredictorHQ(sam) + + elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base): + return SAM2ImagePredictor(sam) + else: + raise ValueError("Model is not an EfficientViTSam or Sam") + + +def get_mask_generator(sam, **kwargs): + if isinstance(sam, Sam): + return automatic_mask_generator.SamAutomaticMaskGenerator(model=sam, **kwargs) + + elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ): + return automatic_mask_generator_hq.SamAutomaticMaskGeneratorHQ(model=sam, **kwargs) + + elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam): + return EfficientViTSamAutomaticMaskGenerator(model=sam, **kwargs) + + elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base): + return SAM2AutomaticMaskGenerator(model=sam) + + else: + raise ValueError("Model is not an EfficientViTSam or Sam") \ No newline at end of file diff --git a/segment_anything_ui/modeling/__init__.py b/segment_anything_ui/modeling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/segment_anything_ui/modeling/efficientvit/sam_model_zoo.py b/segment_anything_ui/modeling/efficientvit/sam_model_zoo.py new file mode 100644 index 0000000..5e90e51 --- /dev/null +++ b/segment_anything_ui/modeling/efficientvit/sam_model_zoo.py @@ -0,0 +1,45 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from segment_anything_ui.modeling.efficientvit.models.efficientvit import ( + EfficientViTSam, + efficientvit_sam_l0, + efficientvit_sam_l1, + efficientvit_sam_l2, +) +from segment_anything_ui.modeling.efficientvit.models.nn.norm import set_norm_eps +from segment_anything_ui.modeling.efficientvit.models.utils import load_state_dict_from_file + +__all__ = ["create_sam_model"] + + +REGISTERED_SAM_MODEL: dict[str, str] = { + "l0": "assets/checkpoints/sam/l0.pt", + "l1": "assets/checkpoints/sam/l1.pt", + "l2": "assets/checkpoints/sam/l2.pt", +} + + +def create_sam_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTSam: + model_dict = { + "l0": efficientvit_sam_l0, + "l1": efficientvit_sam_l1, + "l2": efficientvit_sam_l2, + } + + model_id = name.split("-")[0] + if model_id not in model_dict: + raise ValueError(f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}") + else: + model = model_dict[model_id](**kwargs) + set_norm_eps(model, 1e-6) + + if pretrained: + weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None) + if weight_url is None: + raise ValueError(f"Do not find the pretrained weight of {name}.") + else: + weight = load_state_dict_from_file(weight_url) + model.load_state_dict(weight) + return model diff --git a/segment_anything_ui/modeling/storable_sam.py b/segment_anything_ui/modeling/storable_sam.py new file mode 100644 index 0000000..49f8655 --- /dev/null +++ b/segment_anything_ui/modeling/storable_sam.py @@ -0,0 +1,30 @@ +from safetensors import safe_open +from segment_anything.modeling import Sam +import torch.nn as nn + + +class ModifiedImageEncoder(nn.Module): + + def __init__(self, image_encoder, saved_path: str | None = None): + super().__init__() + self.image_encoder = image_encoder + if saved_path is not None: + self.embeddings = safe_open(saved_path) + else: + self.embeddings = None + + def forward(self, x): + return self.image_encoder(x) if self.embeddings is None else self.embeddings + + +class StorableSam: + + def __init__(self, sam): + self.sam = sam + self.image_encoder = sam.image_encoder + + def transform(self, saved_path): + self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path) + + def precompute(self, image): + return self.image_encoder(image) diff --git a/segment_anything_ui/sam2_configs/sam2.1_hiera_b+.yaml b/segment_anything_ui/sam2_configs/sam2.1_hiera_b+.yaml new file mode 100644 index 0000000..d7172f9 --- /dev/null +++ b/segment_anything_ui/sam2_configs/sam2.1_hiera_b+.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/segment_anything_ui/sam2_configs/sam2.1_hiera_l.yaml b/segment_anything_ui/sam2_configs/sam2.1_hiera_l.yaml new file mode 100644 index 0000000..23073ea --- /dev/null +++ b/segment_anything_ui/sam2_configs/sam2.1_hiera_l.yaml @@ -0,0 +1,120 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/segment_anything_ui/sam2_configs/sam2.1_hiera_s.yaml b/segment_anything_ui/sam2_configs/sam2.1_hiera_s.yaml new file mode 100644 index 0000000..fd8d404 --- /dev/null +++ b/segment_anything_ui/sam2_configs/sam2.1_hiera_s.yaml @@ -0,0 +1,119 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/segment_anything_ui/sam2_configs/sam2.1_hiera_t.yaml b/segment_anything_ui/sam2_configs/sam2.1_hiera_t.yaml new file mode 100644 index 0000000..e762aec --- /dev/null +++ b/segment_anything_ui/sam2_configs/sam2.1_hiera_t.yaml @@ -0,0 +1,121 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/segment_anything_ui/saver.py b/segment_anything_ui/saver.py new file mode 100644 index 0000000..164956c --- /dev/null +++ b/segment_anything_ui/saver.py @@ -0,0 +1,13 @@ +import os +from typing import Any +import torch + + +class Saver: + + def __init__(self, path: str) -> None: + self.path = path + + def __call__(self, basename, mask_annotation) -> Any: + save_path = os.path.join(self.path, basename) + # TODO: finish this \ No newline at end of file diff --git a/segment_anything_ui/segment_anything_control.py b/segment_anything_ui/segment_anything_control.py new file mode 100644 index 0000000..e69de29 diff --git a/segment_anything_ui/settings_layout.py b/segment_anything_ui/settings_layout.py new file mode 100644 index 0000000..490c34d --- /dev/null +++ b/segment_anything_ui/settings_layout.py @@ -0,0 +1,228 @@ +import json +import os +import pathlib +import random + +import cv2 +import numpy as np +from PySide6.QtWidgets import QPushButton, QWidget, QFileDialog, QVBoxLayout, QLineEdit, QLabel, QCheckBox, QMessageBox + +from segment_anything_ui.annotator import MasksAnnotation +from segment_anything_ui.config import Config +from segment_anything_ui.utils.bounding_boxes import BoundingBox + + +class FilesHolder: + def __init__(self): + self.files = [] + self.position = 0 + + def add_files(self, files): + self.files.extend(files) + + def get_next(self): + self.position += 1 + if self.position >= len(self.files): + self.position = 0 + return self.files[self.position] + + def get_previous(self): + self.position -= 1 + + if self.position < 0: + self.position = len(self.files) - 1 + return self.files[self.position] + + +class SettingsLayout(QWidget): + MASK_EXTENSION = "_mask.png" + LABELS_EXTENSION = "_labels.json" + BOUNDING_BOXES_EXTENSION = "_bounding_boxes.json" + + def __init__(self, parent, config: Config) -> None: + super().__init__(parent) + self.config = config + self.actual_file: str = "" + self.actual_shape = self.config.window_size + self.layout = QVBoxLayout(self) + self.open_files = QPushButton("Open Files") + self.open_files.clicked.connect(self.on_open_files) + self.next_file = QPushButton(f"Next File [ {config.key_mapping.NEXT_FILE.name} ]") + self.previous_file = QPushButton(f"Previous file [ {config.key_mapping.PREVIOUS_FILE.name} ]") + self.previous_file.setShortcut(config.key_mapping.PREVIOUS_FILE.key) + self.save_mask = QPushButton(f"Save Mask [ {config.key_mapping.SAVE_MASK.name} ]") + self.save_mask.clicked.connect(self.on_save_mask) + self.save_mask.setShortcut(config.key_mapping.SAVE_MASK.key) + self.save_bounding_boxes = QPushButton(f"Save Bounding Boxes [ {config.key_mapping.SAVE_BOUNDING_BOXES.name} ]") + self.save_bounding_boxes.clicked.connect(self.on_save_bounding_boxes) + self.save_bounding_boxes.setShortcut(config.key_mapping.SAVE_BOUNDING_BOXES.key) + self.next_file.clicked.connect(self.on_next_file) + self.next_file.setShortcut(config.key_mapping.NEXT_FILE.key) + self.previous_file.clicked.connect(self.on_previous_file) + self.checkpoint_path_label = QLabel(self, text="Checkpoint Path") + self.checkpoint_path = QLineEdit(self, text=self.parent().config.default_weights) + self.precompute_button = QPushButton("Precompute all embeddings") + self.precompute_button.clicked.connect(self.on_precompute) + self.delete_existing_annotation = QPushButton("Delete existing annotation") + self.delete_existing_annotation.clicked.connect(self.on_delete_existing_annotation) + self.show_image = QPushButton("Show Image") + self.show_visualization = QPushButton("Show Visualization") + self.show_bounding_boxes = QCheckBox("Show Bounding Boxes") + self.show_bounding_boxes.clicked.connect(self.on_show_bounding_boxes) + self.show_image.clicked.connect(self.on_show_image) + self.show_visualization.clicked.connect(self.on_show_visualization) + self.show_text = QCheckBox("Show Text") + self.show_text.clicked.connect(self.on_show_text) + self.tag_text_field = QLineEdit(self) + self.tag_text_field.setPlaceholderText("Comma separated image tags") + self.layout.addWidget(self.open_files) + self.layout.addWidget(self.next_file) + self.layout.addWidget(self.previous_file) + self.layout.addWidget(self.save_mask) + self.layout.addWidget(self.save_bounding_boxes) + self.layout.addWidget(self.delete_existing_annotation) + self.layout.addWidget(self.show_text) + self.layout.addWidget(self.show_bounding_boxes) + self.layout.addWidget(self.tag_text_field) + self.layout.addWidget(self.checkpoint_path_label) + self.layout.addWidget(self.checkpoint_path) + self.checkpoint_path.returnPressed.connect(self.on_checkpoint_path_changed) + self.checkpoint_path.editingFinished.connect(self.on_checkpoint_path_changed) + self.layout.addWidget(self.precompute_button) + self.layout.addWidget(self.show_image) + self.layout.addWidget(self.show_visualization) + self.files = FilesHolder() + self.original_image = np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8) + + def on_delete_existing_annotation(self): + path = os.path.split(self.actual_file)[0] + basename = os.path.splitext(os.path.basename(self.actual_file))[0] + mask_path = os.path.join(path, basename + self.MASK_EXTENSION) + labels_path = os.path.join(path, basename + self.LABELS_EXTENSION) + if os.path.exists(mask_path): + os.remove(mask_path) + if os.path.exists(labels_path): + os.remove(labels_path) + + def is_show_text(self): + return self.show_text.isChecked() + + def on_show_text(self): + self.parent().update(self.parent().annotator.merge_image_visualization()) + + def on_next_file(self): + file = self.files.get_next() + self._load_image(file) + + def on_previous_file(self): + file = self.files.get_previous() + self._load_image(file) + + def _load_image(self, file: str): + mask = file.split(".")[0] + self.MASK_EXTENSION + labels = file.split(".")[0] + self.LABELS_EXTENSION + bounding_boxes = file.split(".")[0] + self.BOUNDING_BOXES_EXTENSION + image = cv2.imread(file, cv2.IMREAD_UNCHANGED) + self.actual_shape = image.shape[:2][::-1] + self.actual_file = file + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + else: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image.dtype in [np.float32, np.float64, np.uint16]: + image = (image / np.amax(image) * 255).astype("uint8") + #image = np.expand_dims(image[:, :, 2], axis=-1).repeat(3, axis=-1) + image = cv2.resize(image, + (int(self.parent().config.window_size[0]), self.parent().config.window_size[1])) + self.parent().annotator.clear() + self.parent().image_label.clear() + self.original_image = image.copy() + self.parent().set_image(image) + if os.path.exists(mask) and os.path.exists(labels): + self._load_annotation(mask, labels) + self.parent().info_label.setText(f"Loaded annotation from saved files! Image: {file}") + self.parent().update(self.parent().annotator.merge_image_visualization()) + elif os.path.exists(bounding_boxes): + self._load_bounding_boxes(bounding_boxes) + self.parent().info_label.setText(f"Loaded bounding boxes from saved files! Image: {file}") + self.parent().update(self.parent().annotator.merge_image_visualization()) + else: + self.parent().info_label.setText(f"No annotation found! Image: {file}") + self.tag_text_field.setText("") + + def _load_annotation(self, mask, labels): + mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED) + mask = cv2.resize(mask, (self.config.window_size[0], self.config.window_size[1]), + interpolation=cv2.INTER_NEAREST) + with open(labels, "r") as fp: + labels: dict[str, str] = json.load(fp) + masks = [] + new_labels = [] + if "instances" in labels: + instance_labels = labels["instances"] + else: + instance_labels = labels + + if "tags" in labels: + self.tag_text_field.setText(",".join(labels["tags"])) + else: + self.tag_text_field.setText("") + for str_index, class_ in instance_labels.items(): + single_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8) + single_mask[mask == int(str_index)] = 255 + masks.append(single_mask) + new_labels.append(class_) + self.parent().annotator.masks = MasksAnnotation.from_masks(masks, new_labels) + + def _load_bounding_boxes(self, bounding_boxes): + with open(bounding_boxes, "r") as f: + bounding_boxes: list[dict[str, float | str]] = json.load(f) + for bounding_box in bounding_boxes: + self.parent().annotator.bounding_boxes.append(BoundingBox(**bounding_box)) + + def on_show_image(self): + self.parent().set_image(self.original_image, clear_annotations=False) + + def on_show_visualization(self): + self.parent().update(self.parent().annotator.merge_image_visualization()) + + def on_precompute(self): + pass + + def on_save_mask(self): + path = os.path.split(self.actual_file)[0] + tags = self.tag_text_field.text().split(",") + tags = [tag.strip() for tag in tags] + basename = os.path.splitext(os.path.basename(self.actual_file))[0] + mask_path = os.path.join(path, basename + self.MASK_EXTENSION) + labels_path = os.path.join(path, basename + self.LABELS_EXTENSION) + masks = self.parent().get_mask() + labels = {"instances": self.parent().get_labels(), "tags": tags} + with open(labels_path, "w") as f: + json.dump(labels, f, indent=4) + masks = cv2.resize(masks, self.actual_shape, interpolation=cv2.INTER_NEAREST) + cv2.imwrite(mask_path, masks) + + def on_checkpoint_path_changed(self): + self.parent().sam = self.parent().init_sam() + + def on_open_files(self): + files, _ = QFileDialog.getOpenFileNames(self, "Open Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)") + random.shuffle(files) + self.files.add_files(files) + self.on_next_file() + + def on_save_bounding_boxes(self): + path = os.path.split(self.actual_file)[0] + basename = pathlib.Path(self.actual_file).stem + bounding_boxes_path = os.path.join(path, basename + self.BOUNDING_BOXES_EXTENSION) + bounding_boxes = self.parent().get_bounding_boxes() + bounding_boxes_dict = [bounding_box.to_dict() for bounding_box in bounding_boxes] + with open(bounding_boxes_path, "w") as f: + json.dump(bounding_boxes_dict, f, indent=4) + + def is_show_bounding_boxes(self): + return self.show_bounding_boxes.isChecked() + + def on_show_bounding_boxes(self): + self.parent().update(self.parent().annotator.merge_image_visualization()) \ No newline at end of file diff --git a/segment_anything_ui/utils/__init__.py b/segment_anything_ui/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/segment_anything_ui/utils/bounding_boxes.py b/segment_anything_ui/utils/bounding_boxes.py new file mode 100644 index 0000000..4f9f03a --- /dev/null +++ b/segment_anything_ui/utils/bounding_boxes.py @@ -0,0 +1,55 @@ +import dataclasses +import numpy as np + + +@dataclasses.dataclass +class BoundingBox: + x_min: float + y_min: float + x_max: float + y_max: float + label: str + mask_uid: str = "" + + def to_dict(self): + return { + "x_min": self.x_min, + "y_min": self.y_min, + "x_max": self.x_max, + "y_max": self.y_max, + "label": self.label, + "mask_uid": self.mask_uid + } + + @property + def center(self): + return np.array([(self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2]) + + def distance_to(self, point: np.ndarray): + return np.linalg.norm(self.center - point) + + def contains(self, point: np.ndarray): + return self.x_min <= point[0] <= self.x_max and self.y_min <= point[1] <= self.y_max + + +def get_mask_bounding_box(mask, label: str): + where = np.where(mask) + x_min = np.min(where[1]) + y_min = np.min(where[0]) + x_max = np.max(where[1]) + y_max = np.max(where[0]) + return BoundingBox( + x_min / mask.shape[1], + y_min / mask.shape[0], + x_max / mask.shape[1], + y_max / mask.shape[0], + label + ) + + +def get_bounding_boxes(masks, labels): + bounding_boxes = [] + for mask, label in zip(masks, labels): + bounding_box = get_mask_bounding_box(mask, label) + bounding_boxes.append(bounding_box) + return bounding_boxes diff --git a/segment_anything_ui/utils/precompute_folder.py b/segment_anything_ui/utils/precompute_folder.py new file mode 100644 index 0000000..dc01e6a --- /dev/null +++ b/segment_anything_ui/utils/precompute_folder.py @@ -0,0 +1,26 @@ +import glob +import os + +import cv2 +import numpy as np +import torch +import rich +from PIL import Image +import safetensors +from segment_anything import sam_model_registry + +from segment_anything_ui.modeling.storable_sam import StorableSam +from segment_anything_ui.config import Config + +config = Config() +sam = sam_model_registry[config.get_sam_model_name()](checkpoint=config.default_weights) +allowed_extensions = [".jpg", ".png", ".tif", ".tiff"] + + +def load_images_from_folder(folder_path): + images = [] + for filename in os.listdir(folder_path): + allowed_extensions = [".jpg", ".png"] + if any(filename.endswith(ext) for ext in allowed_extensions): + img = Image.open(os.path.join(folder_path, filename)) + return images diff --git a/segment_anything_ui/utils/shapes.py b/segment_anything_ui/utils/shapes.py new file mode 100644 index 0000000..668d789 --- /dev/null +++ b/segment_anything_ui/utils/shapes.py @@ -0,0 +1,53 @@ +import dataclasses + +import cv2 +import numpy as np +from PySide6.QtCore import QPoint +from PySide6.QtGui import QPolygon + + +@dataclasses.dataclass +class BoundingBox: + xstart: float | int + ystart: float | int + xend: float | int = -1. + yend: float | int = -1. + + def to_numpy(self): + return np.array([self.xstart, self.ystart, self.xend, self.yend]) + + def scale(self, sx, sy): + return BoundingBox( + xstart=self.xstart * sx, + ystart=self.ystart * sy, + xend=self.xend * sx, + yend=self.yend * sy + ) + + def to_int(self): + return BoundingBox( + xstart=int(self.xstart), + ystart=int(self.ystart), + xend=int(self.xend), + yend=int(self.yend) + ) + +@dataclasses.dataclass +class Polygon: + points: list = dataclasses.field(default_factory=list) + + def to_numpy(self): + return np.array(self.points).reshape(-1, 2) + + def to_mask(self, num_rows, num_cols): + mask = np.zeros((num_rows, num_cols)) + mask = cv2.fillPoly(mask, pts=[self.to_numpy(), ], color=255) + return mask + + def is_plotable(self): + return len(self.points) > 3 + + def to_qpolygon(self): + return QPolygon([ + QPoint(x, y) for x, y in self.points + ]) diff --git a/segment_anything_ui/utils/tooltips.py b/segment_anything_ui/utils/tooltips.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e69de29