initial_tune

This commit is contained in:
AI-team\cyhan
2025-05-12 11:23:49 +09:00
commit b436a81091
33 changed files with 2398 additions and 0 deletions

View File

View File

@@ -0,0 +1,45 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from segment_anything_ui.modeling.efficientvit.models.efficientvit import (
EfficientViTSam,
efficientvit_sam_l0,
efficientvit_sam_l1,
efficientvit_sam_l2,
)
from segment_anything_ui.modeling.efficientvit.models.nn.norm import set_norm_eps
from segment_anything_ui.modeling.efficientvit.models.utils import load_state_dict_from_file
__all__ = ["create_sam_model"]
REGISTERED_SAM_MODEL: dict[str, str] = {
"l0": "assets/checkpoints/sam/l0.pt",
"l1": "assets/checkpoints/sam/l1.pt",
"l2": "assets/checkpoints/sam/l2.pt",
}
def create_sam_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTSam:
model_dict = {
"l0": efficientvit_sam_l0,
"l1": efficientvit_sam_l1,
"l2": efficientvit_sam_l2,
}
model_id = name.split("-")[0]
if model_id not in model_dict:
raise ValueError(f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}")
else:
model = model_dict[model_id](**kwargs)
set_norm_eps(model, 1e-6)
if pretrained:
weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
if weight_url is None:
raise ValueError(f"Do not find the pretrained weight of {name}.")
else:
weight = load_state_dict_from_file(weight_url)
model.load_state_dict(weight)
return model

View File

@@ -0,0 +1,30 @@
from safetensors import safe_open
from segment_anything.modeling import Sam
import torch.nn as nn
class ModifiedImageEncoder(nn.Module):
def __init__(self, image_encoder, saved_path: str | None = None):
super().__init__()
self.image_encoder = image_encoder
if saved_path is not None:
self.embeddings = safe_open(saved_path)
else:
self.embeddings = None
def forward(self, x):
return self.image_encoder(x) if self.embeddings is None else self.embeddings
class StorableSam:
def __init__(self, sam):
self.sam = sam
self.image_encoder = sam.image_encoder
def transform(self, saved_path):
self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path)
def precompute(self, image):
return self.image_encoder(image)