Files
segment-anything-ui-gpu/segment_anything_ui/modeling/storable_sam.py
AI-team\cyhan b436a81091 initial_tune
2025-05-12 11:23:49 +09:00

31 lines
844 B
Python

from safetensors import safe_open
from segment_anything.modeling import Sam
import torch.nn as nn
class ModifiedImageEncoder(nn.Module):
def __init__(self, image_encoder, saved_path: str | None = None):
super().__init__()
self.image_encoder = image_encoder
if saved_path is not None:
self.embeddings = safe_open(saved_path)
else:
self.embeddings = None
def forward(self, x):
return self.image_encoder(x) if self.embeddings is None else self.embeddings
class StorableSam:
def __init__(self, sam):
self.sam = sam
self.image_encoder = sam.image_encoder
def transform(self, saved_path):
self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path)
def precompute(self, image):
return self.image_encoder(image)