initial_tune

AI-team\cyhan
2025-05-12 11:23:49 +09:00
commit b436a81091
33 changed files with 2398 additions and 0 deletions

7
.gitignore vendored Normal file

@@ -0,0 +1,7 @@
*.pyc
data/
checkpoints/
__pycache__/
*.pt
*.pth
.venv/

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Branislav Hesko
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

121
README.md Normal file

@@ -0,0 +1,121 @@
# Segment Anything UI
A simple annotation UI for Facebook's [Segment anything](https://github.com/facebookresearch/segment-anything) model.
![GUI](./assets/example.png)
# Usage
1. Install segment-anything python package from Github: [Segment anything](https://github.com/facebookresearch/segment-anything). Usually it is enough to run: ```pip install git+https://github.com/facebookresearch/segment-anything.git```.
2. Download checkpoint [Checkpoint_Huge](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) or [Checkpoint_Large](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) or [Checkpoint_Base](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and put it into the workspace folder.
3. Fill default_path in ```segment_anything_ui/config.py```.
4. (Optional) Install EfficientViT models ```pip install git+https://github.com/mit-han-lab/efficientvit```. See the note below for Windows users about installing onnx!
5. (Optional) Install sam_hq models ```pip install segment-anything-hq```
6. Install requirements.txt. ```pip install -r requirements.txt```.
7. If on Ubuntu or a Debian-based distro, run ```apt install libxkbcommon-x11-0 qt5dxcb-plugin libxcb-cursor0```. This fixes issues with Qt.
8. ```export PYTHONPATH=$PYTHONPATH:.```.
9. ```python segment_anything_ui/main_window.py```.
Currently a simple format is used for saving: the mask is saved as a .png file in which instances are represented by values 1 ... n, and the corresponding labels are saved as a JSON file. In the JSON, labels are a map MASK_ID: LABEL, where MASK_ID is the id of the stored mask and LABEL is one of the labels from "labels.json".
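As a quick sanity check, an annotation saved in this format can be read back with a few lines of numpy. This is a minimal sketch; the `image_mask.png` / `image_mask.json` names are hypothetical stand-ins for whatever your save path produces.
```python
import json

import numpy as np
from skimage.io import imread

# Hypothetical file names - the actual names depend on how you save.
mask = imread("image_mask.png")  # values 0 (background) and 1 ... n (instances)
with open("image_mask.json") as f:
    labels = json.load(f)        # {"1": "PRODUCT", "2": "SEPARATOR", ...}

for mask_id in np.unique(mask):
    if mask_id == 0:
        continue  # skip background
    # JSON keys are strings after the round-trip, hence str(mask_id).
    print(mask_id, labels[str(mask_id)], int(np.sum(mask == mask_id)), "px")
```
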
```
For Windows users: the onnx package used by EfficientVit is sometimes hard to install with pip.
The failure may be caused by the long-path limit described here:
https://stackoverflow.com/questions/72352528/how-to-fix-winerror-206-the-filename-or-extension-is-too-long-error/76452218#76452218
To fix this error on your Windows machine, open regedit, navigate to
Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem and set LongPathsEnabled from 0 to 1.
Finally, pin the onnx version; the newest version seems to be broken:
pip install onnx==1.15.0
```
# Checkpoints
Checkpoints are downloaded automatically if the model is not found in the workspace folder.
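For illustration, a minimal sketch of how this works, based on `segment_anything_ui/config.py` later in this commit: constructing `Config` resolves the model name from `default_weights` and, in `__post_init__`, downloads the checkpoint if the file is missing.
```python
from segment_anything_ui.config import Config

# Pick the smaller ViT-B checkpoint; if sam_vit_b_01ec64.pth is missing from
# the working directory, Config.__post_init__ downloads it automatically.
config = Config(default_weights="sam_vit_b_01ec64.pth")
print(config.get_sam_model_name())  # -> "vit_b"
```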
### SAM
- `vit_b`: [ViT-B SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
- `vit_h`: [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)
- `vit_l`: [ViT-L SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
### HQ-SAM
- `vit_b`: [ViT-B HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_b.pth)
- `vit_l`: [ViT-L HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_l.pth)
- `vit_h`: [ViT-H HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_h.pth)
- `vit_tiny` (**Light HQ-SAM** for real-time need): [ViT-Tiny HQ-SAM model](https://huggingface.co/lkeab/hq-sam/blob/main/sam_hq_vit_tiny.pth)
### EfficientViT
- `xl0`: [EfficientViT-XL0 model](https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt)
- `xl1`: [EfficientViT-XL1 model](https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt)
- `l2`: [EfficientViT-L2 model](https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt)
### SAM2
- `sam2.1_hiera_t`: [SAM2.1 Hiera Tiny model](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)
- `sam2.1_hiera_s`: [SAM2.1 Hiera Small model](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
- `sam2.1_hiera_b+`: [SAM2.1 Hiera Base+ model](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)
- `sam2.1_hiera_l`: [SAM2.1 Hiera Large model](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)
# Functions
There are multiple functions that this UI implements. A few of them are:
* Add points - a left mouse click adds a positive point; a right click adds a negative point.
* Add boxes - a bounding box can be added to the SAM annotation when the corresponding annotation tool is selected. Manual points and boxes (and, in the future, polygons) are used as SAM prompts.
* Add manual polygons - with manual annotation selected, clicking in the image with the left mouse button draws a polygon. It does not provide any other features right now.
* An instance mask is saved by clicking the "Save Annotation" button.
* Turn on bounding boxes - the "Turn on bounding boxes" button toggles bounding boxes for the image.
* Save only bounding boxes - the "Save only bounding boxes" button saves only the bounding boxes for the image.
* Annotate All - uses SAM to predict all masks by prompting the model with a grid of points followed by post-processing to refine masks.
* Pick Mask - select a mask from the image to inspect or delete it. Modifications are currently not allowed. Since the annotator allows multiple instances to be assigned to a pixel, left-clicking a pixel cycles through the assigned masks.
* Each instance mask is assigned a class according to the label chosen in the list. The list is loaded from the labels.json file.
* Masks are ordered by the time of their creation, with earlier masks being dominant: masks that were annotated sooner always survive into the final mask, which is an image with values 0 - N for N object instances, so in the saved annotation each pixel has exactly one value. A mask can be moved to the front by picking it and clicking the "Move Current Mask to Front" button; this is especially useful for the "Annotate All" function (see the sketch after this list).
* Because each pixel can have multiple masks assigned, many masks may end up hidden, with only a few of their pixels (mostly on the border) surviving into the final annotation. The "Remove Hidden Masks" button removes all masks (those added later than the dominant mask) whose visible ratio is below the threshold.
* Mask annotations that have been "picked" can be deleted with the "Cancel Annotation" button.
* Cancel annotation - cancels the current annotation: all points, bounding boxes, etc.
* Partial mask - some objects are hard to annotate automatically. Partial mask allows annotating a single instance in parts: each time a partial instance mask is confirmed by clicking the corresponding button, it is merged into the accumulated mask, and the final instance mask is the union of all partial masks.
* Zoom - with the zoom tool a user can zoom onto a part of the image and annotate it; the annotation is then propagated to the whole image.
* Tags - each image can have tags assigned to it. Tags are saved in the annotation file and can be used for filtering images. Use comma-separated values for multiple tags.
* While holding the mouse button, a mask proposal is shown. This is useful for eliminating duplicate clicking.
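To make the dominance rule concrete, here is a minimal numpy sketch (not the application's code, though `Annotator.make_instance_mask` later in this commit works the same way): a weak background channel is stacked with the masks and an argmax is taken, so earlier masks win overlapping pixels because `np.argmax` returns the first maximal index.
```python
import numpy as np

# Two overlapping binary masks (255 = foreground); mask_a was annotated first.
mask_a = np.array([[255, 255, 0, 0]], dtype=np.uint8)
mask_b = np.array([[0, 255, 255, 0]], dtype=np.uint8)

background = np.ones_like(mask_a)         # weak channel so empty pixels stay 0
stack = np.stack([background, mask_a, mask_b])
instance_mask = np.argmax(stack, axis=0)  # ties go to the earliest channel
print(instance_mask)                      # [[1 1 2 0]] - the overlap belongs to mask_a
```
Reordering or rescaling a mask within this stack changes which mask wins a pixel, which is what the mask-reordering buttons manipulate.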
# Buttons
Please note that the buttons are fully configurable. You can change bindings and add new buttons in **segment_anything_ui/config.py** (see the sketch after the table).
| **Button** | **Description** | **Shortcut** |
| --- | --- | --- |
| Add Points | Mouse click will add positive (left) or negative (right) point. | W |
| Add Box | Mouse click will add a box. | Q |
| Annotate All | Runs regular grid annotation with parameters from the form. | Enter |
| Pick Mask | Picks a mask when clicking on it. Cycles through masks if the pixel belongs to multiple masks. | X |
| Merge Masks | WIP: Merges masks when clicking on them. Cycles through masks if the pixel belongs to multiple masks. | Z |
| Move Current Mask to Front | Use the current mask as background (less important). | None |
| Cancel Annotation | Cancels the current annotation. | C |
| Save Annotation | Saves the current annotation. | S |
| Manual Polygon | Draw a polygon with the mouse. | R |
| Partial Mask | Allows splitting a single object into multiple prompts. When pressed, the partial mask is stored and summed with the subsequently prompted masks. | D |
| Remove Hidden Masks | Removes masks that have hidden pixels. | None |
| Zoom to Rectangle | Zoom in on the image. | E |
| ---- | ---- | ---- |
| Open Files | Load files from the folder | None |
| Next File | Load next image | F |
| Previous File | Load previous image | G |
| ---- | ---- | ---- |
| Precompute All Embeddings | Currently not implemented | None |
| Show Image | Currently not implemented | None |
| Show Visualization | Currently not implemented | None |
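As a sketch of such a customization (using the `Config`, `KeyBindings`, and `Keymap` dataclasses from `segment_anything_ui/config.py` in this commit), swapping the default Add Point and Add Box shortcuts could look like this:
```python
from PySide6.QtCore import Qt

from segment_anything_ui.config import Config, KeyBindings, Keymap

# Swap the default W/Q bindings; download_weights_if_not_available=False keeps
# this example from triggering a checkpoint download on construction.
config = Config(
    download_weights_if_not_available=False,
    key_mapping=KeyBindings(
        ADD_POINT=Keymap(Qt.Key.Key_Q, "Q"),
        ADD_BOX=Keymap(Qt.Key.Key_W, "W"),
    ),
)
```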
# TODO:
- [x] - FIX: mouse picker for small objects is not precise.
- [ ] - Region merging.
- [x] - Manual annotation.
- [x] - Saving and loading of masks.
- [x] - Class support for assigning classes to objects.
- [x] - Add object borders.
- [x] - Fix mask size and QLabel size for precise mouse clicks.
- [ ] - Draft mask when no points are visible.
- [x] - Box zoom support.

BIN
assets/example.png Normal file

Binary file not shown.


6
install_script Normal file

@@ -0,0 +1,6 @@
uv venv
source .venv/bin/activate
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
uv pip install git+https://github.com/facebookresearch/segment-anything.git
uv pip install git+https://github.com/facebookresearch/sam2.git
uv pip install -r requirements.txt

4
labels.json Normal file

@@ -0,0 +1,4 @@
{
"PRODUCT": 1,
"SEPARATOR": 2
}

11
requirements.txt Normal file

@@ -0,0 +1,11 @@
pyside6
opencv-python
torch
torchvision
numpy
segment-anything
matplotlib
scikit-image
tqdm
wget
requests

2
run_ui.bat Normal file

@@ -0,0 +1,2 @@
set PYTHONPATH=%PYTHONPATH%;%~dp0
python segment_anything_ui/main_window.py

2
run_ui.sh Normal file

@@ -0,0 +1,2 @@
export PYTHONPATH=$PYTHONPATH:.
python3 segment_anything_ui/main_window.py

218
segment_anything_ui/annotation_layout.py Normal file

@@ -0,0 +1,218 @@
import enum
import json
import os
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QLineEdit, QListWidget, QMessageBox
from segment_anything_ui.draw_label import PaintType
from segment_anything_ui.annotator import AutomaticMaskGeneratorSettings, CustomForm, MasksAnnotation
class MergeState(enum.Enum):
PICKING = enum.auto()
MERGING = enum.auto()
class AnnotationLayout(QWidget):
def __init__(self, parent, config) -> None:
super().__init__(parent)
self.config = config
self.zoom_flag = False
self.merge_state = MergeState.PICKING
self.layout = QVBoxLayout(self)
labels = self._load_labels(config)
self.layout.setAlignment(Qt.AlignTop)
self.add_point = QPushButton(f"Add Point [ {config.key_mapping.ADD_POINT.name} ]")
self.add_box = QPushButton(f"Add Box [ {config.key_mapping.ADD_BOX.name} ]")
self.annotate_all = QPushButton(f"Annotate All [ {config.key_mapping.ANNOTATE_ALL.name} ]")
self.manual_polygon = QPushButton(f"Manual Polygon [ {config.key_mapping.MANUAL_POLYGON.name} ]")
self.cancel_annotation = QPushButton(f"Cancel Annotation [ {config.key_mapping.CANCEL_ANNOTATION.name} ]")
self.save_annotation = QPushButton(f"Save Annotation [ {config.key_mapping.SAVE_ANNOTATION.name} ]")
self.pick_mask = QPushButton(f"Pick Mask [ {config.key_mapping.PICK_MASK.name} ]")
self.pick_bounding_box = QPushButton(f"Pick Bounding Box [ {config.key_mapping.PICK_BOUNDING_BOX.name} ]")
self.merge_masks = QPushButton(f"Merge Masks [ {config.key_mapping.MERGE_MASK.name} ]")
self.delete_mask = QPushButton(f"Delete Mask [ {config.key_mapping.DELETE_MASK.name} ]")
self.partial_annotation = QPushButton(f"Partial Annotation [ {config.key_mapping.PARTIAL_ANNOTATION.name} ]")
self.zoom_rectangle = QPushButton(f"Zoom Rectangle [ {config.key_mapping.ZOOM_RECTANGLE.name} ]")
self.label_picker = QListWidget()
self.label_picker.addItems(labels)
self.label_picker.setCurrentRow(0)
self.move_current_mask_background = QPushButton("Move Current Mask to Front")
self.remove_hidden_masks = QPushButton("Remove Hidden Masks")
self.remove_hidden_masks_label = QLabel("Remove Hidden Masks with hidden area less than")
self.remove_hidden_masks_line = QLineEdit("0.5")
self.save_annotation.setShortcut(config.key_mapping.SAVE_ANNOTATION.key)
self.add_point.setShortcut(config.key_mapping.ADD_POINT.key)
self.add_box.setShortcut(config.key_mapping.ADD_BOX.key)
self.annotate_all.setShortcut(config.key_mapping.ANNOTATE_ALL.key)
self.cancel_annotation.setShortcut(config.key_mapping.CANCEL_ANNOTATION.key)
self.pick_mask.setShortcut(config.key_mapping.PICK_MASK.key)
self.pick_bounding_box.setShortcut(config.key_mapping.PICK_BOUNDING_BOX.key)
self.partial_annotation.setShortcut(config.key_mapping.PARTIAL_ANNOTATION.key)
self.delete_mask.setShortcut(config.key_mapping.DELETE_MASK.key)
self.zoom_rectangle.setShortcut(config.key_mapping.ZOOM_RECTANGLE.key)
self.annotation_settings = CustomForm(self, AutomaticMaskGeneratorSettings())
for w in [
self.add_point,
self.add_box,
self.annotate_all,
self.pick_mask,
self.pick_bounding_box,
self.merge_masks,
self.move_current_mask_background,
self.cancel_annotation,
self.delete_mask,
self.partial_annotation,
self.save_annotation,
self.manual_polygon,
self.label_picker,
self.zoom_rectangle,
self.annotation_settings,
self.remove_hidden_masks,
self.remove_hidden_masks_label,
self.remove_hidden_masks_line
]:
self.layout.addWidget(w)
self.add_point.clicked.connect(self.on_add_point)
self.add_box.clicked.connect(self.on_add_box)
self.annotate_all.clicked.connect(self.on_annotate_all)
self.cancel_annotation.clicked.connect(self.on_cancel_annotation)
self.save_annotation.clicked.connect(self.on_save_annotation)
self.pick_mask.clicked.connect(self.on_pick_mask)
self.pick_bounding_box.clicked.connect(self.on_pick_bounding_box)
self.manual_polygon.clicked.connect(self.on_manual_polygon)
self.remove_hidden_masks.clicked.connect(self.on_remove_hidden_masks)
self.move_current_mask_background.clicked.connect(self.on_move_current_mask_background_fn)
self.merge_masks.clicked.connect(self.on_merge_masks)
self.partial_annotation.clicked.connect(self.on_partial_annotation)
self.delete_mask.clicked.connect(self.on_delete_mask)
self.zoom_rectangle.clicked.connect(self.on_zoom_rectangle)
def on_delete_mask(self):
if self.parent().image_label.paint_type == PaintType.MASK_PICKER:
self.parent().info_label.setText("Deleting mask!")
mask_uid = self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
self.parent().annotator.bounding_boxes.remove(mask_uid)
self.parent().annotator.masks.mask_id = -1
self.parent().annotator.last_mask = None
self.parent().update(self.parent().annotator.merge_image_visualization())
elif self.parent().image_label.paint_type == PaintType.BOX_PICKER:
self.parent().info_label.setText("Deleting bounding box!")
mask_uid = self.parent().annotator.bounding_boxes.remove_by_id(
self.parent().annotator.bounding_boxes.bounding_box_id)
if mask_uid is not None:
self.parent().annotator.masks.pop_by_uuid(mask_uid)
self.parent().annotator.bounding_boxes.bounding_box_id = -1
self.parent().annotator.last_mask = None
self.parent().annotator.masks.mask_id = -1
self.parent().update(self.parent().annotator.merge_image_visualization())
else:
QMessageBox.warning(self, "Error", "Please pick a mask or bounding box to delete!")
def on_partial_annotation(self):
self.parent().info_label.setText("Partial annotation!")
self.parent().annotator.pick_partial_mask()
self.parent().image_label.clear()
@staticmethod
def _load_labels(config):
if not os.path.exists(config.label_file):
return ["default"]
with open(config.label_file, "r") as f:
labels = json.load(f)
MasksAnnotation.DEFAULT_LABEL = list(labels.keys())[0] if len(labels) > 0 else "default"
return labels
def on_merge_masks(self):
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
if self.merge_state == MergeState.PICKING:
self.parent().info_label.setText("Pick a mask to merge with!")
self.merge_state = MergeState.MERGING
self.parent().annotator.merged_mask = self.parent().annotator.last_mask.copy()
elif self.merge_state == MergeState.MERGING:
self.parent().info_label.setText("Merging masks!")
self.parent().annotator.merge_masks()
self.merge_state = MergeState.PICKING
def on_move_current_mask_background_fn(self):
self.parent().info_label.setText("Moving current mask to background!")
self.parent().annotator.move_current_mask_to_background()
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_remove_hidden_masks(self):
self.parent().info_label.setText("Removing hidden masks!")
annotations = self.parent().annotator.masks
argmax_mask = self.parent().annotator.make_instance_mask()
limit_ratio = float(self.remove_hidden_masks_line.text())
new_masks = []
new_labels = []
for idx, (mask, label) in enumerate(annotations):
num_pixels = np.sum(mask > 0)
num_visible = np.sum(argmax_mask == (idx + 1))
ratio = num_visible / num_pixels
if ratio > limit_ratio:
new_masks.append(mask)
new_labels.append(label)
print("Removed ", len(annotations) - len(new_masks), " masks.")
self.parent().annotator.masks = MasksAnnotation.from_masks(new_masks, new_labels)
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_pick_mask(self):
self.parent().info_label.setText("Pick a mask to do modifications!")
self.parent().image_label.change_paint_type(PaintType.MASK_PICKER)
def on_pick_bounding_box(self):
self.parent().info_label.setText("Pick a bounding box to do modifications!")
self.parent().image_label.change_paint_type(PaintType.BOX_PICKER)
def on_manual_polygon(self):
# Sets emphasis on the button
self.manual_polygon.setProperty("active", True)
self.parent().image_label.change_paint_type(PaintType.POLYGON)
def on_add_point(self):
self.parent().info_label.setText("Adding point annotation!")
self.parent().image_label.change_paint_type(PaintType.POINT)
def on_add_box(self):
self.parent().info_label.setText("Adding box annotation!")
self.parent().image_label.change_paint_type(PaintType.BOX)
def on_zoom_rectangle(self):
if self.zoom_flag:
self.parent().info_label.setText("Zooming rectangle OFF!")
self.parent().image_label.change_paint_type(PaintType.POINT)
self.parent().annotator.zoomed_bounding_box = None
self.parent().annotator.make_embedding()
self.parent().update(self.parent().annotator.merge_image_visualization())
self.zoom_flag = False
else:
self.parent().info_label.setText("Pick Mask to zoom!")
self.zoom_rectangle.setText(f"Zoom Rectangle [ {self.config.key_mapping.ZOOM_RECTANGLE.name} ]")
self.parent().image_label.change_paint_type(PaintType.ZOOM_PICKER)
self.zoom_flag = True
def on_annotate_all(self):
self.parent().info_label.setText("Annotating all. This make take some time.")
self.parent().annotator.predict_all(self.annotation_settings.get_values())
self.parent().update(self.parent().annotator.merge_image_visualization())
self.parent().info_label.setText("Annotate all finished.")
def on_cancel_annotation(self):
self.parent().image_label.clear()
self.parent().annotator.clear_last_masks()
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_save_annotation(self):
if self.parent().image_label.paint_type == PaintType.POLYGON:
self.parent().annotator.last_mask = self.parent().image_label.polygon.to_mask(
self.config.window_size[0], self.config.window_size[1])
self.parent().annotator.save_mask(label=self.label_picker.currentItem().text())
self.parent().update(self.parent().annotator.merge_image_visualization())
self.parent().image_label.clear()

445
segment_anything_ui/annotator.py Normal file

@@ -0,0 +1,445 @@
import dataclasses
from typing import Callable
import uuid
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QLineEdit
from segment_anything import SamPredictor
from segment_anything.build_sam import Sam
from segment_anything_ui.model_builder import (
get_predictor, get_mask_generator, SamPredictor)
try:
from segment_anything_ui.model_builder import EfficientViTSamPredictor, EfficientViTSam
except (ImportError, ModuleNotFoundError):
class EfficientViTSamPredictor:
pass
class EfficientViTSam:
pass
from skimage.measure import regionprops
import torch
from segment_anything_ui.utils.shapes import BoundingBox
from segment_anything_ui.utils.bounding_boxes import get_bounding_boxes, get_mask_bounding_box
def get_cmap(n, name='hsv'):
'''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
RGB color; the keyword argument name must be a standard mpl colormap name.'''
try:
return plt.cm.get_cmap(name, n)
    except AttributeError:  # plt.cm.get_cmap was removed in newer matplotlib releases
        return plt.get_cmap(name, n)
def crop_image(
image,
box: BoundingBox | None = None,
image_shape: tuple[int, int] | None = None
):
if image_shape is None:
image_shape = image.shape[:2][::-1]
if box is None:
return cv2.resize(image, image_shape)
if len(image.shape) == 2:
return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend], image_shape)
return cv2.resize(image[box.ystart:box.yend, box.xstart:box.xend, :], image_shape)
def insert_image(image, box: BoundingBox | None = None):
new_image = np.zeros_like(image)
if box is None:
new_image = image
else:
if len(image.shape) == 2:
new_image[box.ystart:box.yend, box.xstart:box.xend] = cv2.resize(
image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
else:
new_image[box.ystart:box.yend, box.xstart:box.xend, :] = cv2.resize(
image.astype(np.uint8), (int(box.xend) - int(box.xstart), int(box.yend) - int(box.ystart)))
return new_image
@dataclasses.dataclass()
class AutomaticMaskGeneratorSettings:
points_per_side: int = 32
pred_iou_thresh: float = 0.88
stability_score_thresh: float = 0.95
stability_score_offset: float = 1.0
box_nms_thresh: float = 0.7
crop_n_layers: int = 0
crop_nms_thresh: float = 0.7
class LabelValueParam(QWidget):
def __init__(self, label_text, default_value, value_type_converter: Callable = lambda x: x, parent=None):
super().__init__(parent)
self.layout = QVBoxLayout(self)
self.layout.setSpacing(0)
self.layout.setContentsMargins(0, 0, 0, 0)
self.label = QLabel(self, text=label_text, alignment=Qt.AlignCenter)
self.value = QLineEdit(self, text=default_value, alignment=Qt.AlignCenter)
self.layout.addWidget(self.label)
self.layout.addWidget(self.value)
self.converter = value_type_converter
def get_value(self):
return self.converter(self.value.text())
class CustomForm(QWidget):
def __init__(self, parent: QWidget, automatic_mask_generator_settings: AutomaticMaskGeneratorSettings) -> None:
super().__init__(parent)
self.layout = QVBoxLayout(self)
self.layout.setSpacing(0)
self.layout.setContentsMargins(0, 0, 0, 0)
self.widgets = []
for field in dataclasses.fields(automatic_mask_generator_settings):
widget = LabelValueParam(field.name, str(field.default), field.type)
self.widgets.append(widget)
self.layout.addWidget(widget)
def get_values(self):
return AutomaticMaskGeneratorSettings(**{widget.label.text(): widget.get_value() for widget in self.widgets})
class BoundingBoxAnnotation:
def __init__(self) -> None:
self.bounding_boxes: list[BoundingBox] = []
self.bounding_box_id: int = -1
def append(self, bounding_box: BoundingBox):
self.bounding_boxes.append(bounding_box)
def find_closest_bounding_box(self, point: np.ndarray):
closest_bounding_box = None
closest_bounding_box_id = -1
min_distance = float('inf')
for idx, bounding_box in enumerate(self.bounding_boxes):
distance = bounding_box.distance_to(point)
if distance < min_distance and bounding_box.contains(point):
min_distance = distance
closest_bounding_box = bounding_box
closest_bounding_box_id = idx
self.bounding_box_id = closest_bounding_box_id
return closest_bounding_box, closest_bounding_box_id
def get_bounding_box(self, bounding_box_id: int):
return self.bounding_boxes[bounding_box_id]
def get_current_bounding_box(self):
return self.bounding_boxes[self.bounding_box_id]
def set_current_bounding_box(self, bounding_box: BoundingBox):
self.bounding_boxes[self.bounding_box_id] = bounding_box
def remove(self, mask_uid: str):
bounding_box_id = next((idx for idx, bounding_box in enumerate(self.bounding_boxes) if bounding_box.mask_uid == mask_uid), None)
if bounding_box_id is None:
return
bounding_box = self.bounding_boxes.pop(bounding_box_id)
if self.bounding_box_id >= bounding_box_id:
self.bounding_box_id -= 1
return bounding_box
def remove_by_id(self, bounding_box_id: int):
mask_uid = self.bounding_boxes[bounding_box_id].mask_uid
self.remove(mask_uid)
return mask_uid
def __len__(self):
return len(self.bounding_boxes)
class MasksAnnotation:
DEFAULT_LABEL = "default"
def __init__(self) -> None:
self.masks = []
self.label_map = {}
self.masks_uids: list[str] = []
self.mask_id: int = -1
def add_mask(self, mask, label: str | None = None):
self.masks.append(mask)
self.masks_uids.append(str(uuid.uuid4()))
self.label_map[len(self.masks)] = self.DEFAULT_LABEL if label is None else label
return self.masks_uids[-1]
def add_label(self, mask_id: int, label: str):
self.label_map[mask_id] = label
def get_mask(self, mask_id: int):
return self.masks[mask_id]
def get_label(self, mask_id: int):
return self.label_map[mask_id]
def get_current_mask(self):
return self.masks[self.mask_id]
def set_current_mask(self, mask, label: str = None):
self.masks[self.mask_id] = mask
self.label_map[self.mask_id] = self.DEFAULT_LABEL if label is None else label
def __getitem__(self, mask_id: int):
return self.get_mask(mask_id)
def __setitem__(self, mask_id: int, value):
self.masks[mask_id] = value
def __len__(self):
return len(self.masks)
def __iter__(self):
return iter(zip(self.masks, self.label_map.values()))
def __next__(self):
if self.mask_id >= len(self.masks):
raise StopIteration
return self.masks[self.mask_id]
def append(self, mask, label: str | None = None):
return self.add_mask(mask, label)
def pop_by_uuid(self, mask_uid: str):
mask_id = next((idx for idx, m_uid in enumerate(self.masks_uids) if m_uid == mask_uid), None)
if mask_id is None:
return
return self.pop(mask_id)
def pop(self, mask_id: int = -1):
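        # label_map is keyed 1..n; drop the matching key and renumber the rest
        # so the map stays aligned with the mask list.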
_ = self.masks.pop(mask_id)
mask_uid = self.masks_uids.pop(mask_id)
self.label_map.pop(mask_id + 1)
new_label_map = {}
for index, value in enumerate(self.label_map.values()):
new_label_map[index + 1] = value
self.label_map = new_label_map
return mask_uid
@classmethod
def from_masks(cls, masks, labels: list[str] | None = None):
annotation = cls()
if labels is None:
labels = [None] * len(masks)
for mask, label in zip(masks, labels):
annotation.append(mask, label)
return annotation
@dataclasses.dataclass()
class Annotator:
sam: Sam | EfficientViTSam | None = None
embedding: torch.Tensor | None = None
image: np.ndarray | None = None
masks: MasksAnnotation = dataclasses.field(default_factory=MasksAnnotation)
bounding_boxes: BoundingBoxAnnotation = dataclasses.field(default_factory=BoundingBoxAnnotation)
predictor: SamPredictor | EfficientViTSamPredictor | None = None
visualization: np.ndarray | None = None
last_mask: np.ndarray | None = None
partial_mask: np.ndarray | None = None
merged_mask: np.ndarray | None = None
parent: QWidget | None = None
cmap: plt.cm = None
original_image: np.ndarray | None = None
zoomed_bounding_box: BoundingBox | None = None
def __post_init__(self):
self.MAX_MASKS = 10
self.cmap = get_cmap(self.MAX_MASKS)
def set_image(self, image: np.ndarray):
self.image = image
return self
def make_embedding(self):
if self.sam is None:
return
self.predictor = get_predictor(self.sam)
self.predictor.set_image(crop_image(self.image, self.zoomed_bounding_box))
def predict_all(self, settings: AutomaticMaskGeneratorSettings):
generator = get_mask_generator(
sam=self.sam,
**dataclasses.asdict(settings)
)
masks = generator.generate(self.image)
masks = [(m["segmentation"] * 255).astype(np.uint8) for m in masks]
label = self.parent.annotation_layout.label_picker.currentItem().text()
self.masks = MasksAnnotation.from_masks(masks, [label, ] * len(masks))
self.cmap = get_cmap(len(self.masks))
def make_prediction(self, annotation: dict):
masks, scores, logits = self.predictor.predict(
point_coords=annotation["points"],
point_labels=annotation["labels"],
box=annotation["bounding_boxes"],
multimask_output=False
)
mask = masks[0]
self.last_mask = insert_image(mask, self.zoomed_bounding_box) * 255
def pick_partial_mask(self):
if self.partial_mask is None:
self.partial_mask = self.last_mask.copy()
else:
self.partial_mask = np.maximum(self.last_mask, self.partial_mask)
self.last_mask = None
def move_current_mask_to_background(self):
self.masks.set_current_mask(self.masks.get_current_mask() * 0.5)
def merge_masks(self):
new_mask = np.bitwise_or(self.last_mask, self.merged_mask)
self.masks.set_current_mask(new_mask, self.parent.annotation_layout.label_picker.currentItem().text())
self.merged_mask = None
def visualize_last_mask(self, label: str | None = None):
last_mask = np.zeros_like(self.image)
last_mask[:, :, 1] = self.last_mask
if self.partial_mask is not None:
last_mask[:, :, 0] = self.partial_mask
if self.merged_mask is not None:
last_mask[:, :, 2] = self.merged_mask
if label is not None:
props = regionprops(self.last_mask)[0]
cv2.putText(
last_mask,
label,
(int(props.centroid[1]), int(props.centroid[0])),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
[255, 255, 255],
2
)
if self.is_show_bounding_boxes:
last_mask_bounding_boxes = get_mask_bounding_box(last_mask[:, :, 1], label)
cv2.rectangle(
last_mask,
(int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
(int(last_mask_bounding_boxes.x_max * self.image.shape[1]), int(last_mask_bounding_boxes.y_max * self.image.shape[0])),
(0, 255, 0),
2
)
cv2.putText(
last_mask,
label,
(int(last_mask_bounding_boxes.x_min * self.image.shape[1]), int(last_mask_bounding_boxes.y_min * self.image.shape[0])),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
[255, 255, 255],
2
)
visualization = cv2.addWeighted(self.image.copy() if self.visualization is None else self.visualization.copy(),
0.8, last_mask, 0.5, 0)
self.parent.update(crop_image(visualization, self.zoomed_bounding_box))
def _visualize_mask(self) -> tuple:
mask_argmax = self.make_instance_mask()
visualization = np.zeros_like(self.image)
border = np.zeros(self.image.shape[:2], dtype=np.uint8)
for i in range(1, np.amax(mask_argmax) + 1):
color = self.cmap(i)
single_mask = np.zeros_like(mask_argmax)
single_mask[mask_argmax == i] = 1
visualization[mask_argmax == i, :] = np.array(color[:3]) * 255
border += single_mask - cv2.erode(
single_mask, np.ones((3, 3), np.uint8), iterations=1)
label = self.masks.get_label(i)
single_mask_center = np.mean(np.where(single_mask == 1), axis=1)
if np.isnan(single_mask_center).any():
continue
if self.parent.settings.is_show_text():
cv2.putText(
visualization,
label,
(int(single_mask_center[1]), int(single_mask_center[0])),
cv2.FONT_HERSHEY_PLAIN,
0.5,
[255, 255, 255],
1
)
if self.is_show_bounding_boxes:
bounding_boxes = self.get_bounding_boxes()
for idx, bounding_box in enumerate(bounding_boxes):
cv2.rectangle(
visualization,
(int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
(int(bounding_box.x_max * self.image.shape[1]), int(bounding_box.y_max * self.image.shape[0])),
(0, 0, 255) if idx != self.bounding_boxes.bounding_box_id else (0, 255, 0),
2
)
cv2.putText(
visualization,
bounding_box.label,
(int(bounding_box.x_min * self.image.shape[1]), int(bounding_box.y_min * self.image.shape[0])),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
[255, 255, 255],
2
)
border = (border == 0).astype(np.uint8)
return visualization, border
def has_annotations(self):
return len(self.masks) > 0
def make_instance_mask(self):
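        # np.argmax returns the first maximal index, so earlier masks win
        # overlapping pixels; the background channel keeps empty pixels at 0.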
background = np.zeros_like(self.masks[0]) + 1
mask_argmax = np.argmax(np.concatenate([np.expand_dims(background, 0), np.array(self.masks.masks)], axis=0), axis=0).astype(np.uint8)
return mask_argmax
def get_bounding_boxes(self):
return get_bounding_boxes(self.masks.masks, self.masks.label_map.values())
def merge_image_visualization(self):
image = self.image.copy()
if not len(self.masks):
return crop_image(image, self.zoomed_bounding_box)
visualization, border = self._visualize_mask()
self.visualization = cv2.addWeighted(image, 0.8, visualization, 0.7, 0) * border[:, :, np.newaxis]
return crop_image(self.visualization, self.zoomed_bounding_box)
    def remove_last_mask(self):
        # pop the last mask (valid indices are 0..len-1) and remove its
        # bounding box via the returned mask uid.
        mask_id = len(self.masks) - 1
        mask_uid = self.masks.pop(mask_id)
        self.bounding_boxes.remove(mask_uid)
def make_labels(self):
return self.masks.label_map
def save_mask(self, label: str = MasksAnnotation.DEFAULT_LABEL):
if self.partial_mask is not None:
last_mask = self.partial_mask
self.partial_mask = None
else:
last_mask = self.last_mask
mask_uid = self.masks.add_mask(last_mask, label=label)
corresponding_bounding_box = get_mask_bounding_box(last_mask, label)
corresponding_bounding_box.mask_uid = mask_uid
self.bounding_boxes.append(corresponding_bounding_box)
if len(self.masks) >= self.MAX_MASKS:
self.MAX_MASKS += 10
self.cmap = get_cmap(self.MAX_MASKS)
@property
def is_show_bounding_boxes(self):
return self.parent.settings.is_show_bounding_boxes()
def clear_last_masks(self):
self.last_mask = None
self.partial_mask = None
self.visualization = None
def clear(self):
self.last_mask = None
self.visualization = None
self.masks = MasksAnnotation()
self.bounding_boxes = BoundingBoxAnnotation()
self.partial_mask = None

143
segment_anything_ui/config.py Normal file

@@ -0,0 +1,143 @@
import dataclasses
import os
from typing import Literal
from PySide6.QtCore import Qt
import requests
try:
from tqdm import tqdm
import wget
except ImportError:
print("Tqdm and wget not found. Install with pip install tqdm wget")
tqdm = None
wget = None
@dataclasses.dataclass(frozen=True)
class Keymap:
key: Qt.Key | str
name: str
class ProgressBar:
def __init__(self):
self.progress_bar = None
def __call__(self, current_bytes, total_bytes, width):
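        # wget invokes the bar callback as bar(current, total, width); convert
        # byte counts to MB and feed tqdm the delta since the last call.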
current_mb = round(current_bytes / 1024 ** 2, 1)
total_mb = round(total_bytes / 1024 ** 2, 1)
if self.progress_bar is None:
self.progress_bar = tqdm(total=total_mb, desc="MB")
delta_mb = current_mb - self.progress_bar.n
self.progress_bar.update(delta_mb)
@dataclasses.dataclass
class KeyBindings:
ADD_POINT: Keymap = Keymap(Qt.Key.Key_W, "W")
ADD_BOX: Keymap = Keymap(Qt.Key.Key_Q, "Q")
ANNOTATE_ALL: Keymap = Keymap(Qt.Key.Key_Return, "Enter")
MANUAL_POLYGON: Keymap = Keymap(Qt.Key.Key_R, "R")
CANCEL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_C, "C")
SAVE_ANNOTATION: Keymap = Keymap(Qt.Key.Key_S, "S")
PICK_MASK: Keymap = Keymap(Qt.Key.Key_X, "X")
PICK_BOUNDING_BOX: Keymap = Keymap(Qt.Key.Key_B, "B")
MERGE_MASK: Keymap = Keymap(Qt.Key.Key_Z, "Z")
DELETE_MASK: Keymap = Keymap(Qt.Key.Key_V, "V")
PARTIAL_ANNOTATION: Keymap = Keymap(Qt.Key.Key_D, "D")
SAVE_BOUNDING_BOXES: Keymap = Keymap("Ctrl+B", "Ctrl+B")
NEXT_FILE: Keymap = Keymap(Qt.Key.Key_F, "F")
PREVIOUS_FILE: Keymap = Keymap(Qt.Key.Key_G, "G")
SAVE_MASK: Keymap = Keymap("Ctrl+S", "Ctrl+S")
PRECOMPUTE: Keymap = Keymap(Qt.Key.Key_P, "P")
ZOOM_RECTANGLE: Keymap = Keymap(Qt.Key.Key_E, "E")
@dataclasses.dataclass
class Config:
default_weights: Literal[
"sam_vit_b_01ec64.pth",
"sam_vit_h_4b8939.pth",
"sam_vit_l_0b3195.pth",
"xl0.pt",
"xl1.pt",
"sam_hq_vit_b.pth",
"sam_hq_vit_l.pth",
"sam_hq_vit_h.pth",
"sam_hq_vit_tiny.pth",
"sam2.1_hiera_t.pth",
"sam2.1_hiera_l.pth",
"sam2.1_hiera_b+.pth",
"sam2.1_hiera_s.pth",
] = "sam_vit_h_4b8939.pth"
download_weights_if_not_available: bool = True
label_file: str = "labels.json"
window_size: tuple[int, int] | int = (1920, 1080)
key_mapping: KeyBindings = dataclasses.field(default_factory=KeyBindings)
weights_paths: dict[str, str] = dataclasses.field(default_factory=lambda: {
"l2": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"xl0": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt",
"xl1": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt",
"hq_vit_b": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
"hq_vit_l": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
"hq_vit_h": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
"hq_vit_tiny": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
"sam2.1_hiera_t": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
"sam2.1_hiera_s": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
"sam2.1_hiera_b+": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
"sam2.1_hiera_l": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
})
def __post_init__(self):
if isinstance(self.window_size, int):
self.window_size = (self.window_size, self.window_size)
if self.download_weights_if_not_available:
self.download_weights()
def get_sam_model_name(self):
if "l2" in self.default_weights:
return "l2"
if "sam_vit_b" in self.default_weights:
return "vit_b"
if "sam_vit_h" in self.default_weights:
return "vit_h"
if "sam_vit_l" in self.default_weights:
return "vit_l"
if "xl0" in self.default_weights:
return "xl0"
if "xl1" in self.default_weights:
return "xl1"
if "hq_vit_b" in self.default_weights:
return "hq_vit_b"
if "hq_vit_l" in self.default_weights:
return "hq_vit_l"
if "hq_vit_h" in self.default_weights:
return "hq_vit_h"
if "hq_vit_tiny" in self.default_weights:
return "hq_vit_tiny"
if "sam2.1_hiera_t" in self.default_weights:
return "sam2.1_hiera_t"
if "sam2.1_hiera_l" in self.default_weights:
return "sam2.1_hiera_l"
if "sam2.1_hiera_b+" in self.default_weights:
return "sam2.1_hiera_b+"
if "sam2.1_hiera_s" in self.default_weights:
return "sam2.1_hiera_s"
raise ValueError("Unknown model name")
def download_weights(self):
if not os.path.exists(self.default_weights):
try:
print(f"Downloading weights for model {self.get_sam_model_name()}")
wget.download(self.weights_paths[self.get_sam_model_name()], self.default_weights, bar=ProgressBar())
print(f"Downloaded weights to {self.default_weights}")
except Exception as e:
print(f"Error downloading weights: {e}. Trying with requests.")
model_name = self.get_sam_model_name()
print(f"Downloading weights for model {model_name}")
file = requests.get(self.weights_paths[model_name])
with open(self.default_weights, "wb") as f:
f.write(file.content)
print(f"Downloaded weights to {self.default_weights}")

268
segment_anything_ui/draw_label.py Normal file

@@ -0,0 +1,268 @@
import copy
from enum import Enum
import cv2
import numpy as np
import PySide6
from PySide6 import QtCore, QtWidgets
from PySide6.QtGui import QPainter, QPen
from segment_anything_ui.config import Config
from segment_anything_ui.utils.shapes import BoundingBox, Polygon
class PaintType(Enum):
POINT = 0
BOX = 1
MASK = 2
POLYGON = 3
MASK_PICKER = 4
ZOOM_PICKER = 5
BOX_PICKER = 6
class MaskIdPicker:
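    # Cycles through the ids of all masks under a clicked pixel, so that
    # overlapping instances can each be reached by repeated clicks.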
def __init__(self, length) -> None:
self.counter = 0
self.length = length
def increment(self):
self.counter = (self.counter + 1) % self.length
def pick(self, ids):
print("Length of ids: ", len(ids), " counter: ", self.counter, " ids: ", ids)
if len(ids) <= self.counter:
self.counter = 0
return_id = ids[self.counter]
self.increment()
return return_id
class DrawLabel(QtWidgets.QLabel):
def __init__(self, parent=None):
super().__init__(parent)
self.positive_points = []
self.negative_points = []
self.bounding_box = None
self.partial_box = BoundingBox(0, 0)
self._paint_type = PaintType.POINT
self.polygon = Polygon()
self.mask_enum: MaskIdPicker = MaskIdPicker(3)
self.config = Config()
self.setFocusPolicy(QtCore.Qt.StrongFocus)
self._zoom_center = (0, 0)
self._zoom_factor = 1.0
self._zoom_bounding_box: BoundingBox | None = None
def paintEvent(self, paint_event):
painter = QPainter(self)
painter.drawPixmap(self.rect(), self.pixmap())
pen_positive = self._get_pen(QtCore.Qt.green, 5)
pen_negative = self._get_pen(QtCore.Qt.red, 5)
pen_partial = self._get_pen(QtCore.Qt.yellow, 1)
pen_box = self._get_pen(QtCore.Qt.green, 1)
painter.setRenderHint(QPainter.Antialiasing, False)
painter.setPen(pen_box)
if self.bounding_box is not None and self.bounding_box.xend != -1 and self.bounding_box.yend != -1:
painter.drawRect(
self.bounding_box.xstart,
self.bounding_box.ystart,
self.bounding_box.xend - self.bounding_box.xstart,
self.bounding_box.yend - self.bounding_box.ystart
)
painter.setPen(pen_partial)
painter.drawRect(self.partial_box.xstart, self.partial_box.ystart,
self.partial_box.xend - self.partial_box.xstart,
self.partial_box.yend - self.partial_box.ystart)
painter.setPen(pen_positive)
for pos in self.positive_points:
painter.drawPoint(pos)
painter.setPen(pen_negative)
painter.setRenderHint(QPainter.Antialiasing, False)
for pos in self.negative_points:
painter.drawPoint(pos)
if self.polygon.is_plotable():
painter.setPen(pen_box)
painter.setRenderHint(QPainter.Antialiasing, True)
painter.drawPolygon(self.polygon.to_qpolygon())
# self.update()
def _get_pen(self, color=QtCore.Qt.red, width=5):
pen = QPen()
pen.setWidth(width)
pen.setColor(color)
return pen
@property
def paint_type(self):
return self._paint_type
def change_paint_type(self, paint_type: PaintType):
print(f"Changing paint type to {paint_type}")
self._paint_type = paint_type
def mouseMoveEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]:
self.partial_box = copy.deepcopy(self.bounding_box)
self.partial_box.xend = ev.pos().x()
self.partial_box.yend = ev.pos().y()
self.update()
if self._paint_type == PaintType.POINT:
point = ev.pos()
if ev.buttons() == QtCore.Qt.LeftButton:
self._move_update(None, point)
elif ev.buttons() == QtCore.Qt.RightButton:
self._move_update(point, None)
else:
pass
self.update()
def _move_update(self, temporary_point_negative, temporary_point_positive):
annotations = self.get_annotations(temporary_point_positive, temporary_point_negative)
self.parent().annotator.make_prediction(annotations)
self.parent().annotator.visualize_last_mask()
def mouseReleaseEvent(self, cursor_event):
if self._paint_type == PaintType.POINT:
if cursor_event.button() == QtCore.Qt.LeftButton:
self.positive_points.append(cursor_event.pos())
print(self.size())
elif cursor_event.button() == QtCore.Qt.RightButton:
self.negative_points.append(cursor_event.pos())
# self.chosen_points.append(self.mapFromGlobal(QtGui.QCursor.pos()))
elif self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER]:
if cursor_event.button() == QtCore.Qt.LeftButton:
self.bounding_box.xend = cursor_event.pos().x()
self.bounding_box.yend = cursor_event.pos().y()
self.partial_box = BoundingBox(-1, -1, -1, -1)
if not self._paint_type == PaintType.MASK_PICKER and \
not self._paint_type == PaintType.ZOOM_PICKER and \
not self._paint_type == PaintType.POLYGON and \
not self._paint_type == PaintType.BOX_PICKER:
self.parent().annotator.make_prediction(self.get_annotations())
self.parent().annotator.visualize_last_mask()
if self._paint_type == PaintType.ZOOM_PICKER:
self.parent().annotator.zoomed_bounding_box = self.bounding_box.scale(*self._get_scale()).to_int()
self.bounding_box = None
self.parent().annotator.make_embedding()
self.parent().update(self.parent().annotator.merge_image_visualization())
self._paint_type = PaintType.POINT
self.update()
def mousePressEvent(self, ev: PySide6.QtGui.QMouseEvent) -> None:
if self._paint_type in [PaintType.BOX, PaintType.ZOOM_PICKER] and ev.button() == QtCore.Qt.LeftButton:
self.bounding_box = BoundingBox(xstart=ev.pos().x(), ystart=ev.pos().y())
if self._paint_type == PaintType.POLYGON and ev.button() == QtCore.Qt.LeftButton:
self.polygon.points.append([ev.pos().x(), ev.pos().y()])
if self._paint_type == PaintType.MASK_PICKER and ev.button() == QtCore.Qt.LeftButton:
size = self.size()
point = [
int(ev.pos().x() / size.width() * self.config.window_size[0]),
int(ev.pos().y() / size.height() * self.config.window_size[1])]
masks = np.array(self.parent().annotator.masks.masks)
mask_ids = np.where(masks[:, point[1], point[0]])[0]
print("Picking mask at point: {}".format(point))
if not(len(mask_ids)):
print("No mask found")
mask_id = -1
local_mask = np.zeros((masks.shape[1], masks.shape[2]))
label = None
else:
mask_id = self.mask_enum.pick(mask_ids)
local_mask = self.parent().annotator.masks.get_mask(mask_id)
label = self.parent().annotator.masks.get_label(mask_id + 1)
self.parent().annotator.masks.mask_id = mask_id
self.parent().annotator.last_mask = local_mask
self.parent().annotator.visualize_last_mask(label)
if self._paint_type == PaintType.BOX_PICKER and ev.button() == QtCore.Qt.LeftButton:
size = self.size()
point = [
float(ev.pos().x() / size.width()),
float(ev.pos().y() / size.height())]
bounding_box, bounding_box_id = self.parent().annotator.bounding_boxes.find_closest_bounding_box(point)
if bounding_box is None:
print("No bounding box found")
else:
self.parent().annotator.bounding_boxes.bounding_box_id = bounding_box_id
print(f"Bounding box: {bounding_box}")
print(f"Bounding box id: {bounding_box_id}")
self.parent().update(self.parent().annotator.merge_image_visualization())
if self._paint_type == PaintType.POINT:
point = ev.pos()
if ev.button() == QtCore.Qt.LeftButton:
self._move_update(None, point)
if ev.button() == QtCore.Qt.RightButton:
self._move_update(point, None)
self.update()
def zoom_to_rectangle(self, xstart, ystart, xend, yend):
picked_image = self.parent().annotator.image[ystart:yend, xstart:xend, :]
self.parent().annotator.image = cv2.resize(picked_image, (self.config.window_size[0], self.config.window_size[1]))
self.update()
def keyPressEvent(self, ev: PySide6.QtGui.QKeyEvent) -> None:
print(ev.key())
if self._paint_type == PaintType.MASK_PICKER and ev.key() == QtCore.Qt.Key.Key_D and len(self.parent().annotator.masks):
print("Deleting mask")
self.parent().annotator.masks.pop(self.parent().annotator.masks.mask_id)
self.parent().annotator.masks.mask_id = -1
self.parent().annotator.last_mask = None
self.parent().update(self.parent().annotator.merge_image_visualization())
def _get_scale(self):
return self.config.window_size[0] / self.size().width(), self.config.window_size[1] / self.size().height()
def get_annotations(
self,
temporary_point_positive: PySide6.QtCore.QPoint | None = None,
temporary_point_negative: PySide6.QtCore.QPoint | None = None
):
sx, sy = self._get_scale()
positive_points = [(
p.x() * sx,
p.y() * sy) for p in self.positive_points]
negative_points = [(
p.x() * sx,
p.y() * sy) for p in self.negative_points]
if temporary_point_positive:
positive_points += [(temporary_point_positive.x() * sx, temporary_point_positive.y() * sy)]
if temporary_point_negative:
negative_points += [(temporary_point_negative.x() * sx, temporary_point_negative.y() * sy)]
positive_points = np.array(positive_points).reshape(-1, 2)
negative_points = np.array(negative_points).reshape(-1, 2)
labels = np.array([1, ] * len(positive_points) + [0, ] * len(negative_points))
print(f"Positive points: {positive_points}")
print(f"Negative points: {negative_points}")
print(f"Labels: {labels}")
return {
"points": np.concatenate([positive_points, negative_points], axis=0),
"labels": labels,
"bounding_boxes": self.bounding_box.scale(sx, sy).to_numpy() if self.bounding_box else None
}
def clear(self):
self.positive_points = []
self.negative_points = []
self.bounding_box = None
self.partial_box = BoundingBox(0, 0, 0, 0)
self.polygon = Polygon()
self.update()

14
segment_anything_ui/image_pixmap.py Normal file

@@ -0,0 +1,14 @@
from PySide6.QtGui import QImage, QPixmap, QPainter, QPen
from PySide6.QtCore import Qt
class ImagePixmap(QPixmap):
def __init__(self):
super().__init__()
def set_image(self, image):
if image.dtype == "uint8":
image = image.astype("float32") / 255.0
image = (image * 255).astype("uint8")
        # Pass the row stride explicitly so padded or non-contiguous arrays render correctly.
        image = QImage(image.data, image.shape[1], image.shape[0], image.strides[0], QImage.Format_RGB888)
self.convertFromImage(image)

81
segment_anything_ui/main_window.py Normal file

@@ -0,0 +1,81 @@
import logging
import sys
import cv2
import numpy as np
import torch
from PySide6.QtWidgets import (QApplication, QGridLayout, QLabel,
QMessageBox, QWidget)
from PySide6.QtCore import Qt
from segment_anything_ui.annotator import Annotator
from segment_anything_ui.annotation_layout import AnnotationLayout
from segment_anything_ui.config import Config
from segment_anything_ui.draw_label import DrawLabel
from segment_anything_ui.image_pixmap import ImagePixmap
from segment_anything_ui.model_builder import build_model
from segment_anything_ui.settings_layout import SettingsLayout
class SegmentAnythingUI(QWidget):
def __init__(self, config) -> None:
super().__init__()
self.config: Config = config
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.setWindowTitle("Segment Anything UI")
self.setWindowState(Qt.WindowState.WindowMaximized)
# self.setGeometry(100, 100, 800, 600)
self.layout = QGridLayout(self)
self.image_label = DrawLabel(self)
self.settings = SettingsLayout(self, config=self.config)
self.info_label = QLabel("Information about running process.")
self.sam = self.init_sam()
self.annotator = Annotator(sam=self.sam, parent=self)
self.annotation_layout = AnnotationLayout(self, config=self.config)
self.layout.addWidget(self.annotation_layout, 0, 0, 1, 1, Qt.AlignCenter)
self.layout.addWidget(self.image_label, 0, 1, 1, 1, Qt.AlignCenter)
self.layout.addWidget(self.settings, 0, 3, 1, 1, Qt.AlignCenter)
self.layout.addWidget(self.info_label, 1, 1, Qt.AlignBottom)
self.set_image(np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8))
self.show()
def set_image(self, image: np.ndarray, clear_annotations: bool = True):
self.annotator.set_image(image).make_embedding()
if clear_annotations:
self.annotator.clear()
self.update(image)
def update(self, image: np.ndarray):
image = cv2.resize(image, self.config.window_size)
pixmap = ImagePixmap()
pixmap.set_image(image)
print("Updating image")
self.image_label.setPixmap(pixmap)
def init_sam(self):
try:
checkpoint_path = str(self.settings.checkpoint_path.text())
sam = build_model(self.config.get_sam_model_name(), checkpoint_path, self.device)
except Exception as e:
logging.getLogger().exception(f"Error loading model: {e}")
QMessageBox.critical(self, "Error", "Could not load model")
return None
return sam
def get_mask(self):
return self.annotator.make_instance_mask()
def get_labels(self):
return self.annotator.make_labels()
def get_bounding_boxes(self):
return self.annotator.get_bounding_boxes()
if __name__ == '__main__':
app = QApplication(sys.argv)
ex = SegmentAnythingUI(Config())
sys.exit(app.exec())

129
segment_anything_ui/model_builder.py Normal file

@@ -0,0 +1,129 @@
import os
from PySide6.QtWidgets import QMessageBox
try:
from efficientvit.sam_model_zoo import create_efficientvit_sam_model, EfficientViTSam
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, EfficientViTSamAutomaticMaskGenerator
IS_EFFICIENT_VIT_AVAILABLE = True
except (ModuleNotFoundError, ImportError) as e:
import logging
logging.warning("Efficient is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
IS_EFFICIENT_VIT_AVAILABLE = False
try:
from segment_anything_hq import sam_model_registry as sam_hq_model_registry
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import automatic_mask_generator as automatic_mask_generator_hq
from segment_anything_hq.build_sam import Sam as SamHQ
IS_SAM_HQ_AVAILABLE = True
_SAM_HQ_MODEL_REGISTRY = {
"hq_vit_b": "vit_b",
"hq_vit_l": "vit_l",
"hq_vit_h": "vit_h",
"hq_vit_tiny": "vit_tiny",
}
except (ModuleNotFoundError, ImportError) as e:
import logging
logging.warning("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
IS_SAM_HQ_AVAILABLE = False
_SAM_HQ_MODEL_REGISTRY = {}
try:
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
IS_SAM2_AVAILABLE = True
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize
# Reset Hydra's global configuration
if GlobalHydra.instance().is_initialized():
GlobalHydra.instance().clear()
_SAM2_MODEL_REGISTRY = {
"sam2.1_hiera_t": "sam2.1_hiera_t.yaml",
"sam2.1_hiera_l": "sam2.1_hiera_l.yaml",
"sam2.1_hiera_b": "sam2.1_hiera_b+.yaml",
"sam2.1_hiera_s": "sam2.1_hiera_s.yaml",
}
except (ModuleNotFoundError, ImportError) as e:
import logging
logging.warning("SAM2 is not available, please install the package from https://github.com/SysCV/sam2 .")
IS_SAM2_AVAILABLE = False
_SAM2_MODEL_REGISTRY = {}
from segment_anything import sam_model_registry
from segment_anything import SamPredictor, automatic_mask_generator
from segment_anything.build_sam import Sam
def build_model(model_name: str, checkpoint_path: str, device: str):
match model_name:
case "xl0" | "xl1":
if not IS_EFFICIENT_VIT_AVAILABLE:
raise ValueError("EfficientViTSam is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
efficientvit_sam = create_efficientvit_sam_model(
name=model_name, weight_url=checkpoint_path,
)
return efficientvit_sam.to(device).eval()
case "vit_b" | "vit_l" | "vit_h":
sam = sam_model_registry[model_name](
checkpoint=checkpoint_path)
sam.eval()
return sam.to(device)
case "hq_vit_b" | "hq_vit_l" | "hq_vit_h":
if not IS_SAM_HQ_AVAILABLE:
QMessageBox.critical(None, "Segment Anything HQ is not available", "Please install the package from http://github.com/SysCV/sam-hq .")
raise ValueError("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
sam = sam_hq_model_registry[_SAM_HQ_MODEL_REGISTRY[model_name]](
checkpoint=checkpoint_path)
sam.eval()
return sam.to(device)
case "sam2.1_hiera_t" | "sam2.1_hiera_l" | "sam2.1_hiera_b" | "sam2.1_hiera_s":
if not IS_SAM2_AVAILABLE:
QMessageBox.critical(None, "SAM2 is not available", "Please install the package from https://github.com/facebookresearch/sam2 .")
raise ValueError("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
with initialize(version_base=None, config_path="sam2_configs"):
sam = build_sam2(_SAM2_MODEL_REGISTRY[model_name], checkpoint_path, device=device)
sam.eval()
return sam
case _:
raise ValueError(f"Model {model_name} not supported")
def get_predictor(sam):
if isinstance(sam, Sam):
return SamPredictor(sam)
elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
return EfficientViTSamPredictor(sam)
elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
return SamPredictorHQ(sam)
elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
return SAM2ImagePredictor(sam)
else:
raise ValueError("Model is not an EfficientViTSam or Sam")
def get_mask_generator(sam, **kwargs):
if isinstance(sam, Sam):
return automatic_mask_generator.SamAutomaticMaskGenerator(model=sam, **kwargs)
elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
return automatic_mask_generator_hq.SamAutomaticMaskGeneratorHQ(model=sam, **kwargs)
elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
return EfficientViTSamAutomaticMaskGenerator(model=sam, **kwargs)
elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
return SAM2AutomaticMaskGenerator(model=sam)
else:
raise ValueError("Model is not an EfficientViTSam or Sam")

View File

View File

@@ -0,0 +1,45 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from segment_anything_ui.modeling.efficientvit.models.efficientvit import (
EfficientViTSam,
efficientvit_sam_l0,
efficientvit_sam_l1,
efficientvit_sam_l2,
)
from segment_anything_ui.modeling.efficientvit.models.nn.norm import set_norm_eps
from segment_anything_ui.modeling.efficientvit.models.utils import load_state_dict_from_file
__all__ = ["create_sam_model"]
REGISTERED_SAM_MODEL: dict[str, str] = {
"l0": "assets/checkpoints/sam/l0.pt",
"l1": "assets/checkpoints/sam/l1.pt",
"l2": "assets/checkpoints/sam/l2.pt",
}
def create_sam_model(name: str, pretrained=True, weight_url: str | None = None, **kwargs) -> EfficientViTSam:
model_dict = {
"l0": efficientvit_sam_l0,
"l1": efficientvit_sam_l1,
"l2": efficientvit_sam_l2,
}
model_id = name.split("-")[0]
if model_id not in model_dict:
raise ValueError(f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}")
else:
model = model_dict[model_id](**kwargs)
set_norm_eps(model, 1e-6)
if pretrained:
weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
if weight_url is None:
raise ValueError(f"Do not find the pretrained weight of {name}.")
else:
weight = load_state_dict_from_file(weight_url)
model.load_state_dict(weight)
return model
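A short usage sketch: with `pretrained=True` and no explicit `weight_url`, the loader falls back to the registered path above, which must exist locally.

```python
# Assumes assets/checkpoints/sam/l0.pt has been downloaded beforehand.
sam = create_sam_model("l0", pretrained=True)
sam = sam.cuda().eval()
```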

View File

@@ -0,0 +1,30 @@
from safetensors import safe_open
from segment_anything.modeling import Sam
import torch.nn as nn
class ModifiedImageEncoder(nn.Module):
    """Image encoder wrapper that can return cached embeddings instead of recomputing them."""
    def __init__(self, image_encoder, saved_path: str | None = None):
        super().__init__()
        self.image_encoder = image_encoder
        if saved_path is not None:
            # safetensors' safe_open requires the framework argument.
            self.embeddings = safe_open(saved_path, framework="pt")
        else:
            self.embeddings = None
    def forward(self, x):
        # Fall back to the real encoder when no cached embeddings were provided.
        return self.image_encoder(x) if self.embeddings is None else self.embeddings
class StorableSam:
    """Wraps a Sam model so its image encoder can be swapped for precomputed embeddings."""
    def __init__(self, sam: Sam):
        self.sam = sam
        self.image_encoder = sam.image_encoder
    def transform(self, saved_path):
        # Replace the live encoder with one backed by embeddings stored on disk.
        self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path)
    def precompute(self, image):
        return self.image_encoder(image)
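A possible round trip for the cached embeddings; the key name, tensor shape, and file name are illustrative assumptions, not fixed by this module:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Save: run the encoder once and cache its output (shape is a stand-in).
embedding = torch.randn(1, 256, 64, 64)
save_file({"embedding": embedding}, "image_0001.safetensors")

# Load: read the cached tensor back instead of re-running the encoder.
with safe_open("image_0001.safetensors", framework="pt", device="cpu") as f:
    cached = f.get_tensor("embedding")
```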

View File

@@ -0,0 +1,116 @@
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [896, 448, 224, 112]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
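This config (and its siblings below) is consumed through Hydra by `build_sam2`, mirroring the `build_model` branch above; the checkpoint file name here is a placeholder:

```python
from hydra import initialize
from sam2.build_sam import build_sam2

with initialize(version_base=None, config_path="sam2_configs"):
    # Placeholder checkpoint path; the yaml name matches the registry entry above.
    sam2_model = build_sam2("sam2.1_hiera_b+.yaml", "sam2.1_hiera_base_plus.pt", device="cuda")
```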

View File

@@ -0,0 +1,120 @@
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 144
num_heads: 2
stages: [2, 6, 36, 4]
global_att_blocks: [23, 33, 43]
window_pos_embed_bkg_spatial_size: [7, 7]
window_spec: [8, 4, 16, 8]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [1152, 576, 288, 144]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False

View File

@@ -0,0 +1,119 @@
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False

View File

@@ -0,0 +1,121 @@
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
# SAM decoder
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
# HieraT does not currently support compilation, should always be set to False
compile_image_encoder: False

View File

@@ -0,0 +1,13 @@
import os
from typing import Any
import torch
class Saver:
def __init__(self, path: str) -> None:
self.path = path
def __call__(self, basename, mask_annotation) -> Any:
save_path = os.path.join(self.path, basename)
# TODO: finish this
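The TODO above is kept as-is; one hypothetical completion, assuming the annotation object can simply be serialized with `torch.save` (file format and naming are assumptions, not the author's design):

```python
import os
from typing import Any

import torch


class SaverSketch:
    """Illustrative completion of Saver; the .pt format is an assumption."""
    def __init__(self, path: str) -> None:
        self.path = path

    def __call__(self, basename: str, mask_annotation: Any) -> None:
        os.makedirs(self.path, exist_ok=True)
        save_path = os.path.join(self.path, basename)
        # Hypothetical: persist the whole annotation object as a torch file.
        torch.save(mask_annotation, save_path + ".pt")
```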

View File

@@ -0,0 +1,228 @@
import json
import os
import pathlib
import random
import cv2
import numpy as np
from PySide6.QtWidgets import QPushButton, QWidget, QFileDialog, QVBoxLayout, QLineEdit, QLabel, QCheckBox, QMessageBox
from segment_anything_ui.annotator import MasksAnnotation
from segment_anything_ui.config import Config
from segment_anything_ui.utils.bounding_boxes import BoundingBox
class FilesHolder:
    def __init__(self):
        self.files = []
        # Start before the first file so the initial get_next() returns files[0].
        self.position = -1
def add_files(self, files):
self.files.extend(files)
def get_next(self):
self.position += 1
if self.position >= len(self.files):
self.position = 0
return self.files[self.position]
def get_previous(self):
self.position -= 1
if self.position < 0:
self.position = len(self.files) - 1
return self.files[self.position]
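# With the start position fixed above, a hypothetical usage of FilesHolder:
#   holder = FilesHolder()
#   holder.add_files(["a.png", "b.png", "c.png"])
#   holder.get_next()      -> "a.png"  (first call yields the first file)
#   holder.get_next()      -> "b.png"
#   holder.get_previous()  -> "a.png"  (indices wrap at both ends)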
class SettingsLayout(QWidget):
MASK_EXTENSION = "_mask.png"
LABELS_EXTENSION = "_labels.json"
BOUNDING_BOXES_EXTENSION = "_bounding_boxes.json"
def __init__(self, parent, config: Config) -> None:
super().__init__(parent)
self.config = config
self.actual_file: str = ""
self.actual_shape = self.config.window_size
self.layout = QVBoxLayout(self)
self.open_files = QPushButton("Open Files")
self.open_files.clicked.connect(self.on_open_files)
self.next_file = QPushButton(f"Next File [ {config.key_mapping.NEXT_FILE.name} ]")
self.previous_file = QPushButton(f"Previous file [ {config.key_mapping.PREVIOUS_FILE.name} ]")
self.previous_file.setShortcut(config.key_mapping.PREVIOUS_FILE.key)
self.save_mask = QPushButton(f"Save Mask [ {config.key_mapping.SAVE_MASK.name} ]")
self.save_mask.clicked.connect(self.on_save_mask)
self.save_mask.setShortcut(config.key_mapping.SAVE_MASK.key)
self.save_bounding_boxes = QPushButton(f"Save Bounding Boxes [ {config.key_mapping.SAVE_BOUNDING_BOXES.name} ]")
self.save_bounding_boxes.clicked.connect(self.on_save_bounding_boxes)
self.save_bounding_boxes.setShortcut(config.key_mapping.SAVE_BOUNDING_BOXES.key)
self.next_file.clicked.connect(self.on_next_file)
self.next_file.setShortcut(config.key_mapping.NEXT_FILE.key)
self.previous_file.clicked.connect(self.on_previous_file)
self.checkpoint_path_label = QLabel(self, text="Checkpoint Path")
self.checkpoint_path = QLineEdit(self, text=self.parent().config.default_weights)
self.precompute_button = QPushButton("Precompute all embeddings")
self.precompute_button.clicked.connect(self.on_precompute)
self.delete_existing_annotation = QPushButton("Delete existing annotation")
self.delete_existing_annotation.clicked.connect(self.on_delete_existing_annotation)
self.show_image = QPushButton("Show Image")
self.show_visualization = QPushButton("Show Visualization")
self.show_bounding_boxes = QCheckBox("Show Bounding Boxes")
self.show_bounding_boxes.clicked.connect(self.on_show_bounding_boxes)
self.show_image.clicked.connect(self.on_show_image)
self.show_visualization.clicked.connect(self.on_show_visualization)
self.show_text = QCheckBox("Show Text")
self.show_text.clicked.connect(self.on_show_text)
self.tag_text_field = QLineEdit(self)
self.tag_text_field.setPlaceholderText("Comma separated image tags")
self.layout.addWidget(self.open_files)
self.layout.addWidget(self.next_file)
self.layout.addWidget(self.previous_file)
self.layout.addWidget(self.save_mask)
self.layout.addWidget(self.save_bounding_boxes)
self.layout.addWidget(self.delete_existing_annotation)
self.layout.addWidget(self.show_text)
self.layout.addWidget(self.show_bounding_boxes)
self.layout.addWidget(self.tag_text_field)
self.layout.addWidget(self.checkpoint_path_label)
self.layout.addWidget(self.checkpoint_path)
self.checkpoint_path.returnPressed.connect(self.on_checkpoint_path_changed)
self.checkpoint_path.editingFinished.connect(self.on_checkpoint_path_changed)
self.layout.addWidget(self.precompute_button)
self.layout.addWidget(self.show_image)
self.layout.addWidget(self.show_visualization)
self.files = FilesHolder()
self.original_image = np.zeros((self.config.window_size[1], self.config.window_size[0], 3), dtype=np.uint8)
def on_delete_existing_annotation(self):
path = os.path.split(self.actual_file)[0]
basename = os.path.splitext(os.path.basename(self.actual_file))[0]
mask_path = os.path.join(path, basename + self.MASK_EXTENSION)
labels_path = os.path.join(path, basename + self.LABELS_EXTENSION)
if os.path.exists(mask_path):
os.remove(mask_path)
if os.path.exists(labels_path):
os.remove(labels_path)
def is_show_text(self):
return self.show_text.isChecked()
def on_show_text(self):
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_next_file(self):
file = self.files.get_next()
self._load_image(file)
def on_previous_file(self):
file = self.files.get_previous()
self._load_image(file)
def _load_image(self, file: str):
mask = file.split(".")[0] + self.MASK_EXTENSION
labels = file.split(".")[0] + self.LABELS_EXTENSION
bounding_boxes = file.split(".")[0] + self.BOUNDING_BOXES_EXTENSION
image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
self.actual_shape = image.shape[:2][::-1]
self.actual_file = file
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
else:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if image.dtype in [np.float32, np.float64, np.uint16]:
image = (image / np.amax(image) * 255).astype("uint8")
        image = cv2.resize(image,
                           (int(self.parent().config.window_size[0]), int(self.parent().config.window_size[1])))
self.parent().annotator.clear()
self.parent().image_label.clear()
self.original_image = image.copy()
self.parent().set_image(image)
if os.path.exists(mask) and os.path.exists(labels):
self._load_annotation(mask, labels)
self.parent().info_label.setText(f"Loaded annotation from saved files! Image: {file}")
self.parent().update(self.parent().annotator.merge_image_visualization())
elif os.path.exists(bounding_boxes):
self._load_bounding_boxes(bounding_boxes)
self.parent().info_label.setText(f"Loaded bounding boxes from saved files! Image: {file}")
self.parent().update(self.parent().annotator.merge_image_visualization())
else:
self.parent().info_label.setText(f"No annotation found! Image: {file}")
self.tag_text_field.setText("")
def _load_annotation(self, mask, labels):
mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
mask = cv2.resize(mask, (self.config.window_size[0], self.config.window_size[1]),
interpolation=cv2.INTER_NEAREST)
with open(labels, "r") as fp:
labels: dict[str, str] = json.load(fp)
masks = []
new_labels = []
if "instances" in labels:
instance_labels = labels["instances"]
else:
instance_labels = labels
if "tags" in labels:
self.tag_text_field.setText(",".join(labels["tags"]))
else:
self.tag_text_field.setText("")
for str_index, class_ in instance_labels.items():
single_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)
single_mask[mask == int(str_index)] = 255
masks.append(single_mask)
new_labels.append(class_)
self.parent().annotator.masks = MasksAnnotation.from_masks(masks, new_labels)
def _load_bounding_boxes(self, bounding_boxes):
with open(bounding_boxes, "r") as f:
bounding_boxes: list[dict[str, float | str]] = json.load(f)
for bounding_box in bounding_boxes:
self.parent().annotator.bounding_boxes.append(BoundingBox(**bounding_box))
def on_show_image(self):
self.parent().set_image(self.original_image, clear_annotations=False)
def on_show_visualization(self):
self.parent().update(self.parent().annotator.merge_image_visualization())
def on_precompute(self):
pass
def on_save_mask(self):
path = os.path.split(self.actual_file)[0]
tags = self.tag_text_field.text().split(",")
tags = [tag.strip() for tag in tags]
basename = os.path.splitext(os.path.basename(self.actual_file))[0]
mask_path = os.path.join(path, basename + self.MASK_EXTENSION)
labels_path = os.path.join(path, basename + self.LABELS_EXTENSION)
masks = self.parent().get_mask()
labels = {"instances": self.parent().get_labels(), "tags": tags}
with open(labels_path, "w") as f:
json.dump(labels, f, indent=4)
masks = cv2.resize(masks, self.actual_shape, interpolation=cv2.INTER_NEAREST)
cv2.imwrite(mask_path, masks)
def on_checkpoint_path_changed(self):
self.parent().sam = self.parent().init_sam()
def on_open_files(self):
files, _ = QFileDialog.getOpenFileNames(self, "Open Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)")
random.shuffle(files)
self.files.add_files(files)
self.on_next_file()
def on_save_bounding_boxes(self):
path = os.path.split(self.actual_file)[0]
basename = pathlib.Path(self.actual_file).stem
bounding_boxes_path = os.path.join(path, basename + self.BOUNDING_BOXES_EXTENSION)
bounding_boxes = self.parent().get_bounding_boxes()
bounding_boxes_dict = [bounding_box.to_dict() for bounding_box in bounding_boxes]
with open(bounding_boxes_path, "w") as f:
json.dump(bounding_boxes_dict, f, indent=4)
def is_show_bounding_boxes(self):
return self.show_bounding_boxes.isChecked()
def on_show_bounding_boxes(self):
self.parent().update(self.parent().annotator.merge_image_visualization())

View File

View File

@@ -0,0 +1,55 @@
import dataclasses
import numpy as np
@dataclasses.dataclass
class BoundingBox:
x_min: float
y_min: float
x_max: float
y_max: float
label: str
mask_uid: str = ""
def to_dict(self):
return {
"x_min": self.x_min,
"y_min": self.y_min,
"x_max": self.x_max,
"y_max": self.y_max,
"label": self.label,
"mask_uid": self.mask_uid
}
@property
def center(self):
return np.array([(self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2])
def distance_to(self, point: np.ndarray):
return np.linalg.norm(self.center - point)
def contains(self, point: np.ndarray):
return self.x_min <= point[0] <= self.x_max and self.y_min <= point[1] <= self.y_max
def get_mask_bounding_box(mask, label: str):
where = np.where(mask)
x_min = np.min(where[1])
y_min = np.min(where[0])
x_max = np.max(where[1])
y_max = np.max(where[0])
return BoundingBox(
x_min / mask.shape[1],
y_min / mask.shape[0],
x_max / mask.shape[1],
y_max / mask.shape[0],
label
)
def get_bounding_boxes(masks, labels):
bounding_boxes = []
for mask, label in zip(masks, labels):
bounding_box = get_mask_bounding_box(mask, label)
bounding_boxes.append(bounding_box)
return bounding_boxes
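A quick sketch of these helpers on a synthetic mask; the label and mask contents are illustrative:

```python
import json

import numpy as np

# 100x100 mask with one rectangular object labeled "cell".
mask = np.zeros((100, 100), dtype=np.uint8)
mask[20:40, 30:60] = 255
boxes = get_bounding_boxes([mask], ["cell"])
print(boxes[0].center)                            # normalized (x, y), ~[0.445, 0.295]
print(boxes[0].contains(np.array([0.45, 0.3])))   # True
# Same structure that SettingsLayout.on_save_bounding_boxes writes out.
print(json.dumps([box.to_dict() for box in boxes], indent=4))
```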

View File

@@ -0,0 +1,26 @@
import glob
import os
import cv2
import numpy as np
import torch
import rich
from PIL import Image
import safetensors
from segment_anything import sam_model_registry
from segment_anything_ui.modeling.storable_sam import StorableSam
from segment_anything_ui.config import Config
config = Config()
sam = sam_model_registry[config.get_sam_model_name()](checkpoint=config.default_weights)
allowed_extensions = [".jpg", ".png", ".tif", ".tiff"]
def load_images_from_folder(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        # Use the module-level allowed_extensions and actually collect the images.
        if any(filename.endswith(ext) for ext in allowed_extensions):
            images.append(Image.open(os.path.join(folder_path, filename)))
    return images

View File

@@ -0,0 +1,53 @@
import dataclasses
import cv2
import numpy as np
from PySide6.QtCore import QPoint
from PySide6.QtGui import QPolygon
@dataclasses.dataclass
class BoundingBox:
xstart: float | int
ystart: float | int
xend: float | int = -1.
yend: float | int = -1.
def to_numpy(self):
return np.array([self.xstart, self.ystart, self.xend, self.yend])
def scale(self, sx, sy):
return BoundingBox(
xstart=self.xstart * sx,
ystart=self.ystart * sy,
xend=self.xend * sx,
yend=self.yend * sy
)
def to_int(self):
return BoundingBox(
xstart=int(self.xstart),
ystart=int(self.ystart),
xend=int(self.xend),
yend=int(self.yend)
)
@dataclasses.dataclass
class Polygon:
points: list = dataclasses.field(default_factory=list)
def to_numpy(self):
return np.array(self.points).reshape(-1, 2)
def to_mask(self, num_rows, num_cols):
        mask = np.zeros((num_rows, num_cols), dtype=np.uint8)
        # cv2.fillPoly needs integer point coordinates.
        mask = cv2.fillPoly(mask, pts=[self.to_numpy().astype(np.int32)], color=255)
return mask
    def is_plotable(self):
        # A polygon needs at least three vertices to be drawn.
        return len(self.points) >= 3
def to_qpolygon(self):
return QPolygon([
QPoint(x, y) for x, y in self.points
])
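For example, rasterizing a clicked rectangle into a binary mask (after the integer-cast fix above); the coordinates are illustrative:

```python
import numpy as np

poly = Polygon(points=[(10, 10), (90, 10), (90, 80), (10, 80)])
assert poly.is_plotable()
mask = poly.to_mask(num_rows=100, num_cols=100)
assert mask.shape == (100, 100) and mask.max() == 255
```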

View File

0
setup.py Normal file
View File