129 lines
5.5 KiB
Python
129 lines
5.5 KiB
Python
import os
|
|
from PySide6.QtWidgets import QMessageBox
|
|
|
|
# Optional backend: EfficientViT-SAM. The import is attempted once at module
# load; IS_EFFICIENT_VIT_AVAILABLE records whether it succeeded so that later
# code can check the flag before touching any efficientvit name.
try:
    from efficientvit.sam_model_zoo import create_efficientvit_sam_model, EfficientViTSam
    from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, EfficientViTSamAutomaticMaskGenerator
    IS_EFFICIENT_VIT_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    import logging
    # The message names the package in full ("EfficientViT") so the warning
    # can be matched to the install link that follows it.
    logging.warning("EfficientViT is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
    IS_EFFICIENT_VIT_AVAILABLE = False
|
|
|
|
# Optional backend: Segment Anything HQ. Exposes IS_SAM_HQ_AVAILABLE and
# _SAM_HQ_MODEL_REGISTRY, which maps this module's "hq_*" model names to the
# keys of segment_anything_hq's own model registry.
try:
    from segment_anything_hq import (
        sam_model_registry as sam_hq_model_registry,
        SamPredictor as SamPredictorHQ,
        automatic_mask_generator as automatic_mask_generator_hq,
    )
    from segment_anything_hq.build_sam import Sam as SamHQ
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so this single clause
    # covers both failure modes.
    import logging

    IS_SAM_HQ_AVAILABLE = False
    _SAM_HQ_MODEL_REGISTRY = {}
    logging.warning("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
else:
    # Only reached when every import above succeeded.
    IS_SAM_HQ_AVAILABLE = True
    _SAM_HQ_MODEL_REGISTRY = dict(
        hq_vit_b="vit_b",
        hq_vit_l="vit_l",
        hq_vit_h="vit_h",
        hq_vit_tiny="vit_tiny",
    )
|
|
|
|
# Optional backend: SAM2 (plus Hydra, which is imported here for use when
# building SAM2 models). Exposes IS_SAM2_AVAILABLE and _SAM2_MODEL_REGISTRY,
# which maps this module's model names to SAM2 config file names.
try:
    from sam2.build_sam import build_sam2
    from sam2.modeling.sam2_base import SAM2Base
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    IS_SAM2_AVAILABLE = True
    from hydra.core.global_hydra import GlobalHydra
    from hydra import initialize

    # Reset Hydra's global configuration.
    # NOTE(review): presumably required because Hydra may already have been
    # initialized by the time this module loads, and a later `initialize(...)`
    # call needs a clean global state -- confirm against the sam2 package.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    _SAM2_MODEL_REGISTRY = {
        "sam2.1_hiera_t": "sam2.1_hiera_t.yaml",
        "sam2.1_hiera_l": "sam2.1_hiera_l.yaml",
        # The "b" key maps to the "b+" config file name.
        "sam2.1_hiera_b": "sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_s": "sam2.1_hiera_s.yaml",
    }
except (ModuleNotFoundError, ImportError):
    import logging
    # The install link matches the one build_model reports for SAM2
    # (facebookresearch/sam2), so both messages send users to the same place.
    logging.warning("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
    IS_SAM2_AVAILABLE = False
    _SAM2_MODEL_REGISTRY = {}
|
|
|
|
from segment_anything import sam_model_registry
|
|
from segment_anything import SamPredictor, automatic_mask_generator
|
|
from segment_anything.build_sam import Sam
|
|
|
|
|
|
def build_model(model_name: str, checkpoint_path: str, device: str):
|
|
match model_name:
|
|
case "xl0" | "xl1":
|
|
if not IS_EFFICIENT_VIT_AVAILABLE:
|
|
raise ValueError("EfficientViTSam is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
|
|
efficientvit_sam = create_efficientvit_sam_model(
|
|
name=model_name, weight_url=checkpoint_path,
|
|
)
|
|
return efficientvit_sam.to(device).eval()
|
|
|
|
case "vit_b" | "vit_l" | "vit_h":
|
|
sam = sam_model_registry[model_name](
|
|
checkpoint=checkpoint_path)
|
|
sam.eval()
|
|
return sam.to(device)
|
|
|
|
case "hq_vit_b" | "hq_vit_l" | "hq_vit_h":
|
|
if not IS_SAM_HQ_AVAILABLE:
|
|
QMessageBox.critical(None, "Segment Anything HQ is not available", "Please install the package from http://github.com/SysCV/sam-hq .")
|
|
raise ValueError("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
|
|
sam = sam_hq_model_registry[_SAM_HQ_MODEL_REGISTRY[model_name]](
|
|
checkpoint=checkpoint_path)
|
|
sam.eval()
|
|
return sam.to(device)
|
|
case "sam2.1_hiera_t" | "sam2.1_hiera_l" | "sam2.1_hiera_b" | "sam2.1_hiera_s":
|
|
if not IS_SAM2_AVAILABLE:
|
|
QMessageBox.critical(None, "SAM2 is not available", "Please install the package from https://github.com/facebookresearch/sam2 .")
|
|
raise ValueError("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
|
|
|
|
with initialize(version_base=None, config_path="sam2_configs"):
|
|
sam = build_sam2(_SAM2_MODEL_REGISTRY[model_name], checkpoint_path, device=device)
|
|
sam.eval()
|
|
return sam
|
|
case _:
|
|
raise ValueError(f"Model {model_name} not supported")
|
|
|
|
|
|
def get_predictor(sam):
    """Return the prompt-based predictor that wraps ``sam``.

    The model's class decides which predictor is built. Optional backends are
    only considered when their availability flag is set, so their class names
    are never evaluated if the corresponding import failed.

    Args:
        sam: A model instance from one of the supported backends.

    Returns:
        ``SamPredictor``, ``EfficientViTSamPredictor``, ``SamPredictorHQ`` or
        ``SAM2ImagePredictor`` wrapping ``sam``.

    Raises:
        ValueError: If ``sam`` is not an instance of any available backend's
            model class.
    """
    # Guard clauses, tested in a fixed order: SAM, EfficientViT, HQ, SAM2.
    if isinstance(sam, Sam):
        return SamPredictor(sam)

    if IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
        return EfficientViTSamPredictor(sam)

    if IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
        return SamPredictorHQ(sam)

    if IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        return SAM2ImagePredictor(sam)

    raise ValueError("Model is not an EfficientViTSam or Sam")
|
|
|
|
|
|
def get_mask_generator(sam, **kwargs):
    """Return the automatic mask generator that matches ``sam``'s backend.

    Args:
        sam: A model instance from one of the supported backends.
        **kwargs: Generator options, forwarded unchanged to the selected
            backend's generator constructor.

    Returns:
        ``SamAutomaticMaskGenerator``, ``SamAutomaticMaskGeneratorHQ``,
        ``EfficientViTSamAutomaticMaskGenerator`` or
        ``SAM2AutomaticMaskGenerator`` built around ``sam``.

    Raises:
        ValueError: If ``sam`` is not an instance of any available backend's
            model class.
    """
    # Optional backends are guarded by their availability flag so that their
    # class names are never evaluated when the import failed.
    if isinstance(sam, Sam):
        return automatic_mask_generator.SamAutomaticMaskGenerator(model=sam, **kwargs)

    elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
        return automatic_mask_generator_hq.SamAutomaticMaskGeneratorHQ(model=sam, **kwargs)

    elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
        return EfficientViTSamAutomaticMaskGenerator(model=sam, **kwargs)

    elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        # Forward the caller's options here as well, so that generator
        # settings are honoured for SAM2 models like for every other backend.
        return SAM2AutomaticMaskGenerator(model=sam, **kwargs)

    else:
        raise ValueError("Model is not an EfficientViTSam or Sam")