initial_tune
This commit is contained in:
30
segment_anything_ui/modeling/storable_sam.py
Normal file
30
segment_anything_ui/modeling/storable_sam.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from safetensors import safe_open
|
||||
from segment_anything.modeling import Sam
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModifiedImageEncoder(nn.Module):
|
||||
|
||||
def __init__(self, image_encoder, saved_path: str | None = None):
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
if saved_path is not None:
|
||||
self.embeddings = safe_open(saved_path)
|
||||
else:
|
||||
self.embeddings = None
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_encoder(x) if self.embeddings is None else self.embeddings
|
||||
|
||||
|
||||
class StorableSam:
|
||||
|
||||
def __init__(self, sam):
|
||||
self.sam = sam
|
||||
self.image_encoder = sam.image_encoder
|
||||
|
||||
def transform(self, saved_path):
|
||||
self.image_encoder = ModifiedImageEncoder(self.image_encoder, saved_path)
|
||||
|
||||
def precompute(self, image):
|
||||
return self.image_encoder(image)
|
||||
Reference in New Issue
Block a user