Files
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

144 lines
6.0 KiB
Python

import dataclasses
import os
from typing import Literal
from PySide6.QtCore import Qt
import requests
try:
from tqdm import tqdm
import wget
except ImportError:
print("Tqdm and wget not found. Install with pip install tqdm wget")
tqdm = None
wget = None
@dataclasses.dataclass(frozen=True)
class Keymap:
    """Immutable pairing of a key binding with its text label.

    Being frozen, instances are hashable and can be shared freely (for example
    as class-level defaults on another dataclass).
    """

    # Either a ``Qt.Key`` member or a key-sequence string such as "Ctrl+S".
    key: Qt.Key | str
    # Text label for the binding, e.g. "W", "Enter" or "Ctrl+B".
    name: str
class ProgressBar:
    """Progress hook passed as ``bar=`` to ``wget.download``; draws a tqdm bar in MB.

    The tqdm bar is created lazily on the first call, once the total size is
    known. It is closed when the download reaches its total, so later
    ``print`` output starts on a fresh line instead of trailing the bar.
    """

    def __init__(self):
        # Created on the first call; reset to None after the bar is closed.
        self.progress_bar = None

    def __call__(self, current_bytes, total_bytes, width):
        """Advance the bar to ``current_bytes`` out of ``total_bytes``.

        Args:
            current_bytes: Bytes received so far.
            total_bytes: Expected size in bytes. May be non-positive when the
                size is unknown (urllib reports -1 in that case); the bar is
                then shown without a total.
            width: Console width supplied by wget. Unused, because tqdm sizes
                itself.

        Returns:
            None. NOTE(review): per wget's ``bar`` callback protocol, a falsy
            return should keep wget from drawing its own text bar; confirm
            against the wget version in use.
        """
        if tqdm is None:
            # tqdm is an optional dependency (guarded import); show no progress.
            return None
        size_known = total_bytes > 0
        current_mb = round(max(current_bytes, 0) / 1024 ** 2, 1)
        if self.progress_bar is None:
            total_mb = round(total_bytes / 1024 ** 2, 1) if size_known else None
            self.progress_bar = tqdm(total=total_mb, desc="MB", unit="MB")
        # tqdm.update() takes an increment, so convert the absolute position.
        self.progress_bar.update(current_mb - self.progress_bar.n)
        if size_known and current_bytes >= total_bytes:
            # Finalize the bar so it ends with a newline and frees its resources.
            self.progress_bar.close()
            self.progress_bar = None
        return None
@dataclasses.dataclass
class KeyBindings:
    """Default keyboard shortcut for each action, one ``Keymap`` per field.

    Every default is a frozen (and therefore hashable) ``Keymap``. dataclasses
    therefore accepts it as a plain default rather than requiring a
    ``default_factory``, and sharing one instance across ``KeyBindings``
    objects is safe.

    Single keys are ``Qt.Key`` members; modifier combinations are given as
    key-sequence strings.
    """

    ADD_POINT: Keymap = Keymap(key=Qt.Key.Key_W, name="W")
    ADD_BOX: Keymap = Keymap(key=Qt.Key.Key_Q, name="Q")
    ANNOTATE_ALL: Keymap = Keymap(key=Qt.Key.Key_Return, name="Enter")
    MANUAL_POLYGON: Keymap = Keymap(key=Qt.Key.Key_R, name="R")
    CANCEL_ANNOTATION: Keymap = Keymap(key=Qt.Key.Key_C, name="C")
    SAVE_ANNOTATION: Keymap = Keymap(key=Qt.Key.Key_S, name="S")
    PICK_MASK: Keymap = Keymap(key=Qt.Key.Key_X, name="X")
    PICK_BOUNDING_BOX: Keymap = Keymap(key=Qt.Key.Key_B, name="B")
    MERGE_MASK: Keymap = Keymap(key=Qt.Key.Key_Z, name="Z")
    DELETE_MASK: Keymap = Keymap(key=Qt.Key.Key_V, name="V")
    PARTIAL_ANNOTATION: Keymap = Keymap(key=Qt.Key.Key_D, name="D")
    # Chorded shortcut: stored as a key-sequence string, not a Qt.Key.
    SAVE_BOUNDING_BOXES: Keymap = Keymap(key="Ctrl+B", name="Ctrl+B")
    NEXT_FILE: Keymap = Keymap(key=Qt.Key.Key_F, name="F")
    PREVIOUS_FILE: Keymap = Keymap(key=Qt.Key.Key_G, name="G")
    # Chorded shortcut: stored as a key-sequence string, not a Qt.Key.
    SAVE_MASK: Keymap = Keymap(key="Ctrl+S", name="Ctrl+S")
    PRECOMPUTE: Keymap = Keymap(key=Qt.Key.Key_P, name="P")
    ZOOM_RECTANGLE: Keymap = Keymap(key=Qt.Key.Key_E, name="E")
@dataclasses.dataclass
class Config:
    """Top-level settings: model weights, label file, window size and key bindings.

    Construction has side effects (see ``__post_init__``):

    - An integer ``window_size`` is expanded to a square ``(size, size)`` tuple.
    - When ``download_weights_if_not_available`` is true, the checkpoint named
      by ``default_weights`` is downloaded if it is not present on disk. This
      may perform network I/O.
    """

    # Checkpoint file to use. Its file name selects the model variant (see
    # get_sam_model_name); a path ending in one of these names also works.
    default_weights: Literal[
        "l2.pt",
        "sam_vit_b_01ec64.pth",
        "sam_vit_h_4b8939.pth",
        "sam_vit_l_0b3195.pth",
        "xl0.pt",
        "xl1.pt",
        "sam_hq_vit_b.pth",
        "sam_hq_vit_l.pth",
        "sam_hq_vit_h.pth",
        "sam_hq_vit_tiny.pth",
        "sam2.1_hiera_t.pth",
        "sam2.1_hiera_l.pth",
        "sam2.1_hiera_b+.pth",
        "sam2.1_hiera_s.pth",
    ] = "sam_vit_h_4b8939.pth"
    download_weights_if_not_available: bool = True
    label_file: str = "labels.json"
    # Either a pair or a single int for a square window. NOTE(review): the
    # 1920x1080 default suggests (width, height) ordering — confirm with callers.
    window_size: tuple[int, int] | int = (1920, 1080)
    key_mapping: KeyBindings = dataclasses.field(default_factory=KeyBindings)
    # Model name (as returned by get_sam_model_name) -> checkpoint download URL.
    weights_paths: dict[str, str] = dataclasses.field(default_factory=lambda: {
        "l2": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/l2.pt",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "xl0": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl0.pt",
        "xl1": "https://huggingface.co/han-cai/efficientvit-sam/resolve/main/xl1.pt",
        "hq_vit_b": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
        "hq_vit_l": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
        "hq_vit_h": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
        "hq_vit_tiny": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
        "sam2.1_hiera_t": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
        "sam2.1_hiera_s": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
        "sam2.1_hiera_b+": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
        "sam2.1_hiera_l": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
    })

    def __post_init__(self):
        """Normalize ``window_size`` and optionally fetch missing weights."""
        if isinstance(self.window_size, int):
            self.window_size = (self.window_size, self.window_size)
        if self.download_weights_if_not_available:
            self.download_weights()

    def get_sam_model_name(self):
        """Return the ``weights_paths`` key for the model in ``default_weights``.

        The variant is identified by a substring of the checkpoint's file name.
        Directory components are ignored, so e.g. a folder named ``level2``
        cannot be mistaken for the ``l2`` model. Patterns are tried in order
        and the first match wins.

        Raises:
            ValueError: If the file name matches no known model.
        """
        file_name = os.path.basename(self.default_weights)
        # (substring of the file name, model name) pairs, checked in order.
        known_models = (
            ("l2", "l2"),
            ("sam_vit_b", "vit_b"),
            ("sam_vit_h", "vit_h"),
            ("sam_vit_l", "vit_l"),
            ("xl0", "xl0"),
            ("xl1", "xl1"),
            ("hq_vit_b", "hq_vit_b"),
            ("hq_vit_l", "hq_vit_l"),
            ("hq_vit_h", "hq_vit_h"),
            ("hq_vit_tiny", "hq_vit_tiny"),
            ("sam2.1_hiera_t", "sam2.1_hiera_t"),
            ("sam2.1_hiera_l", "sam2.1_hiera_l"),
            ("sam2.1_hiera_b+", "sam2.1_hiera_b+"),
            ("sam2.1_hiera_s", "sam2.1_hiera_s"),
        )
        for pattern, model_name in known_models:
            if pattern in file_name:
                return model_name
        raise ValueError("Unknown model name")

    def download_weights(self):
        """Download the checkpoint to ``default_weights`` if it does not exist.

        wget, with a tqdm progress bar, is tried first when it is installed.
        If it is missing or the attempt fails, the file is streamed with
        requests instead.

        Raises:
            ValueError: If ``default_weights`` names no known model.
            requests.RequestException: If the requests download fails,
                including non-2xx HTTP responses.
            OSError: If the file cannot be written.
        """
        if os.path.exists(self.default_weights):
            return
        # Resolve once, up front: an unknown name should fail immediately rather
        # than be reported as a wget error and then re-raised from the fallback.
        model_name = self.get_sam_model_name()
        url = self.weights_paths[model_name]
        print(f"Downloading weights for model {model_name}")
        if wget is None:
            print("wget is not installed. Downloading with requests.")
        else:
            try:
                wget.download(url, self.default_weights, bar=ProgressBar())
                print(f"Downloaded weights to {self.default_weights}")
                return
            except Exception as e:
                # Deliberate best-effort: any wget failure falls back to requests.
                print(f"Error downloading weights: {e}. Trying with requests.")
        self._download_with_requests(url)
        print(f"Downloaded weights to {self.default_weights}")

    def _download_with_requests(self, url, chunk_size=1024 * 1024, timeout=60):
        """Stream ``url`` into ``default_weights`` through a temporary ``.part`` file.

        - Streaming keeps memory use flat for multi-GB checkpoints.
        - The data goes to a side file that is renamed into place only after a
          complete, successful download. A failed or interrupted download
          therefore never leaves a truncated file at ``default_weights``, which
          ``download_weights`` would otherwise treat as already present.
        - ``timeout`` (seconds) bounds the connect and each read wait, so a
          stalled server cannot hang startup forever.
        """
        part_path = f"{self.default_weights}.part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Without this, an HTTP error page would be saved as the weights.
                response.raise_for_status()
                with open(part_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
            os.replace(part_path, self.default_weights)
        finally:
            # After a successful os.replace the side file is gone; otherwise
            # remove the partial download.
            if os.path.exists(part_path):
                os.remove(part_path)