31 lines
844 B
Python
31 lines
844 B
Python
from safetensors import safe_open
|
|
from segment_anything.modeling import Sam
|
|
import torch.nn as nn
|
|
|
|
|
|
class ModifiedImageEncoder(nn.Module):
|
|
|
|
def __init__(self, image_encoder, saved_path: str | None = None):
|
|
super().__init__()
|
|
self.image_encoder = image_encoder
|
|
if saved_path is not None:
|
|
self.embeddings = safe_open(saved_path)
|
|
else:
|
|
self.embeddings = None
|
|
|
|
def forward(self, x):
|
|
return self.image_encoder(x) if self.embeddings is None else self.embeddings
|
|
|
|
|
|
class StorableSam:
    """Holds a SAM model plus a swappable image encoder.

    ``transform`` replaces ``self.image_encoder`` with a
    ``ModifiedImageEncoder`` built around the model's unwrapped encoder.
    ``precompute`` then either runs the live encoder or returns embeddings
    loaded from disk.

    NOTE(review): ``self.sam.image_encoder`` is never replaced, so calling
    ``self.sam`` directly still runs the original encoder — confirm this is
    intended.
    """

    def __init__(self, sam: "Sam"):
        self.sam = sam
        self.image_encoder = sam.image_encoder

    def transform(self, saved_path):
        """Wrap the underlying encoder, optionally backed by ``saved_path``.

        Any wrapper installed by a previous call is unwrapped first. Repeated
        calls therefore replace the cached embeddings rather than nesting
        wrappers, and ``transform(None)`` restores live encoding.
        """
        base = self.image_encoder
        if isinstance(base, ModifiedImageEncoder):
            base = base.image_encoder
        self.image_encoder = ModifiedImageEncoder(base, saved_path)

    def precompute(self, image):
        """Run the current image encoder on ``image`` and return its output."""
        return self.image_encoder(image)
|