initial_tune
This commit is contained in:
0
segment_anything_ui/utils/__init__.py
Normal file
0
segment_anything_ui/utils/__init__.py
Normal file
55
segment_anything_ui/utils/bounding_boxes.py
Normal file
55
segment_anything_ui/utils/bounding_boxes.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BoundingBox:
|
||||
x_min: float
|
||||
y_min: float
|
||||
x_max: float
|
||||
y_max: float
|
||||
label: str
|
||||
mask_uid: str = ""
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"x_min": self.x_min,
|
||||
"y_min": self.y_min,
|
||||
"x_max": self.x_max,
|
||||
"y_max": self.y_max,
|
||||
"label": self.label,
|
||||
"mask_uid": self.mask_uid
|
||||
}
|
||||
|
||||
@property
|
||||
def center(self):
|
||||
return np.array([(self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2])
|
||||
|
||||
def distance_to(self, point: np.ndarray):
|
||||
return np.linalg.norm(self.center - point)
|
||||
|
||||
def contains(self, point: np.ndarray):
|
||||
return self.x_min <= point[0] <= self.x_max and self.y_min <= point[1] <= self.y_max
|
||||
|
||||
|
||||
def get_mask_bounding_box(mask, label: str):
|
||||
where = np.where(mask)
|
||||
x_min = np.min(where[1])
|
||||
y_min = np.min(where[0])
|
||||
x_max = np.max(where[1])
|
||||
y_max = np.max(where[0])
|
||||
return BoundingBox(
|
||||
x_min / mask.shape[1],
|
||||
y_min / mask.shape[0],
|
||||
x_max / mask.shape[1],
|
||||
y_max / mask.shape[0],
|
||||
label
|
||||
)
|
||||
|
||||
|
||||
def get_bounding_boxes(masks, labels):
|
||||
bounding_boxes = []
|
||||
for mask, label in zip(masks, labels):
|
||||
bounding_box = get_mask_bounding_box(mask, label)
|
||||
bounding_boxes.append(bounding_box)
|
||||
return bounding_boxes
|
||||
26
segment_anything_ui/utils/precompute_folder.py
Normal file
26
segment_anything_ui/utils/precompute_folder.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import glob
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import rich
|
||||
from PIL import Image
|
||||
import safetensors
|
||||
from segment_anything import sam_model_registry
|
||||
|
||||
from segment_anything_ui.modeling.storable_sam import StorableSam
|
||||
from segment_anything_ui.config import Config
|
||||
|
||||
config = Config()
|
||||
sam = sam_model_registry[config.get_sam_model_name()](checkpoint=config.default_weights)
|
||||
allowed_extensions = [".jpg", ".png", ".tif", ".tiff"]
|
||||
|
||||
|
||||
def load_images_from_folder(folder_path):
|
||||
images = []
|
||||
for filename in os.listdir(folder_path):
|
||||
allowed_extensions = [".jpg", ".png"]
|
||||
if any(filename.endswith(ext) for ext in allowed_extensions):
|
||||
img = Image.open(os.path.join(folder_path, filename))
|
||||
return images
|
||||
53
segment_anything_ui/utils/shapes.py
Normal file
53
segment_anything_ui/utils/shapes.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import dataclasses
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PySide6.QtCore import QPoint
|
||||
from PySide6.QtGui import QPolygon
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BoundingBox:
|
||||
xstart: float | int
|
||||
ystart: float | int
|
||||
xend: float | int = -1.
|
||||
yend: float | int = -1.
|
||||
|
||||
def to_numpy(self):
|
||||
return np.array([self.xstart, self.ystart, self.xend, self.yend])
|
||||
|
||||
def scale(self, sx, sy):
|
||||
return BoundingBox(
|
||||
xstart=self.xstart * sx,
|
||||
ystart=self.ystart * sy,
|
||||
xend=self.xend * sx,
|
||||
yend=self.yend * sy
|
||||
)
|
||||
|
||||
def to_int(self):
|
||||
return BoundingBox(
|
||||
xstart=int(self.xstart),
|
||||
ystart=int(self.ystart),
|
||||
xend=int(self.xend),
|
||||
yend=int(self.yend)
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Polygon:
|
||||
points: list = dataclasses.field(default_factory=list)
|
||||
|
||||
def to_numpy(self):
|
||||
return np.array(self.points).reshape(-1, 2)
|
||||
|
||||
def to_mask(self, num_rows, num_cols):
|
||||
mask = np.zeros((num_rows, num_cols))
|
||||
mask = cv2.fillPoly(mask, pts=[self.to_numpy(), ], color=255)
|
||||
return mask
|
||||
|
||||
def is_plotable(self):
|
||||
return len(self.points) > 3
|
||||
|
||||
def to_qpolygon(self):
|
||||
return QPolygon([
|
||||
QPoint(x, y) for x, y in self.points
|
||||
])
|
||||
0
segment_anything_ui/utils/tooltips.py
Normal file
0
segment_anything_ui/utils/tooltips.py
Normal file
Reference in New Issue
Block a user