Files
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

129 lines
5.5 KiB
Python

import os
from PySide6.QtWidgets import QMessageBox
# Optional backend: EfficientViT-SAM.
# IS_EFFICIENT_VIT_AVAILABLE records whether the efficientvit imports below
# succeeded, so callers can check it before touching any efficientvit name.
try:
    from efficientvit.sam_model_zoo import create_efficientvit_sam_model, EfficientViTSam
    from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, EfficientViTSamAutomaticMaskGenerator
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    import logging
    logging.warning("EfficientViT is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
    IS_EFFICIENT_VIT_AVAILABLE = False
else:
    # Only reached when both imports succeeded.
    IS_EFFICIENT_VIT_AVAILABLE = True
# Optional backend: Segment Anything HQ.
# Start from the "not installed" state and upgrade it only when every import
# below succeeds.
IS_SAM_HQ_AVAILABLE = False
_SAM_HQ_MODEL_REGISTRY = {}
try:
    from segment_anything_hq import sam_model_registry as sam_hq_model_registry
    from segment_anything_hq import SamPredictor as SamPredictorHQ
    from segment_anything_hq import automatic_mask_generator as automatic_mask_generator_hq
    from segment_anything_hq.build_sam import Sam as SamHQ
except ImportError:  # also covers ModuleNotFoundError, its subclass
    import logging
    logging.warning("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
else:
    IS_SAM_HQ_AVAILABLE = True
    # Maps this module's "hq_"-prefixed model names to the keys used to index
    # sam_hq_model_registry.
    _SAM_HQ_MODEL_REGISTRY = {
        f"hq_{arch}": arch for arch in ("vit_b", "vit_l", "vit_h", "vit_tiny")
    }
# Optional backend: SAM 2.
# IS_SAM2_AVAILABLE records whether both sam2 and hydra could be imported;
# _SAM2_MODEL_REGISTRY maps this module's model names to SAM2 config file names.
try:
    from sam2.build_sam import build_sam2
    from sam2.modeling.sam2_base import SAM2Base
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from hydra.core.global_hydra import GlobalHydra
    from hydra import initialize
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    import logging
    logging.warning("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
    IS_SAM2_AVAILABLE = False
    _SAM2_MODEL_REGISTRY = {}
else:
    # Reset Hydra's global configuration.
    # NOTE(review): presumably needed so that a later hydra `initialize(...)`
    # call with this project's own config path does not collide with a Hydra
    # instance set up while importing sam2 -- confirm.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    IS_SAM2_AVAILABLE = True
    _SAM2_MODEL_REGISTRY = {
        "sam2.1_hiera_t": "sam2.1_hiera_t.yaml",
        "sam2.1_hiera_l": "sam2.1_hiera_l.yaml",
        "sam2.1_hiera_b": "sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_s": "sam2.1_hiera_s.yaml",
    }
from segment_anything import sam_model_registry
from segment_anything import SamPredictor, automatic_mask_generator
from segment_anything.build_sam import Sam
def build_model(model_name: str, checkpoint_path: str, device: str):
match model_name:
case "xl0" | "xl1":
if not IS_EFFICIENT_VIT_AVAILABLE:
raise ValueError("EfficientViTSam is not available, please install the package from https://github.com/mit-han-lab/efficientvit/tree/master .")
efficientvit_sam = create_efficientvit_sam_model(
name=model_name, weight_url=checkpoint_path,
)
return efficientvit_sam.to(device).eval()
case "vit_b" | "vit_l" | "vit_h":
sam = sam_model_registry[model_name](
checkpoint=checkpoint_path)
sam.eval()
return sam.to(device)
case "hq_vit_b" | "hq_vit_l" | "hq_vit_h":
if not IS_SAM_HQ_AVAILABLE:
QMessageBox.critical(None, "Segment Anything HQ is not available", "Please install the package from http://github.com/SysCV/sam-hq .")
raise ValueError("Segment Anything HQ is not available, please install the package from http://github.com/SysCV/sam-hq .")
sam = sam_hq_model_registry[_SAM_HQ_MODEL_REGISTRY[model_name]](
checkpoint=checkpoint_path)
sam.eval()
return sam.to(device)
case "sam2.1_hiera_t" | "sam2.1_hiera_l" | "sam2.1_hiera_b" | "sam2.1_hiera_s":
if not IS_SAM2_AVAILABLE:
QMessageBox.critical(None, "SAM2 is not available", "Please install the package from https://github.com/facebookresearch/sam2 .")
raise ValueError("SAM2 is not available, please install the package from https://github.com/facebookresearch/sam2 .")
with initialize(version_base=None, config_path="sam2_configs"):
sam = build_sam2(_SAM2_MODEL_REGISTRY[model_name], checkpoint_path, device=device)
sam.eval()
return sam
case _:
raise ValueError(f"Model {model_name} not supported")
def get_predictor(sam):
    """Return the prompt-based predictor that wraps the given model.

    The model's class decides which predictor is built. Classes from optional
    backends are only inspected when that backend was imported successfully.

    Raises:
        ValueError: If *sam* is not an instance of any supported model class.
    """
    # Core Segment Anything model.
    if isinstance(sam, Sam):
        return SamPredictor(sam)

    # Optional backends: the availability flag is tested first so that names
    # from a missing package are never evaluated.
    if IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
        return EfficientViTSamPredictor(sam)
    if IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
        return SamPredictorHQ(sam)
    if IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        return SAM2ImagePredictor(sam)

    raise ValueError("Model is not an EfficientViTSam or Sam")
def get_mask_generator(sam, **kwargs):
    """Return the automatic mask generator that matches the given model.

    Args:
        sam: A model instance from one of the supported backends.
        **kwargs: Generator options, forwarded unchanged to the backend's
            mask-generator constructor.

    Returns:
        The backend-specific automatic mask generator built around *sam*.

    Raises:
        ValueError: If *sam* is not an instance of any supported model class.
    """
    if isinstance(sam, Sam):
        return automatic_mask_generator.SamAutomaticMaskGenerator(model=sam, **kwargs)
    elif IS_SAM_HQ_AVAILABLE and isinstance(sam, SamHQ):
        return automatic_mask_generator_hq.SamAutomaticMaskGeneratorHQ(model=sam, **kwargs)
    elif IS_EFFICIENT_VIT_AVAILABLE and isinstance(sam, EfficientViTSam):
        return EfficientViTSamAutomaticMaskGenerator(model=sam, **kwargs)
    elif IS_SAM2_AVAILABLE and isinstance(sam, SAM2Base):
        # Forward the caller's options here too; previously they were silently
        # dropped for SAM2 while every other backend honoured them.
        return SAM2AutomaticMaskGenerator(model=sam, **kwargs)
    else:
        raise ValueError("Model is not an EfficientViTSam or Sam")