First commit

services/deepencoder/build_linear.py (new file, 174 lines)
@@ -0,0 +1,174 @@
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class MlpProjector(nn.Module):

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [
                nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
            modules = nn.Sequential(*modules)
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
        self.layers = modules

    def forward(self, x):
        if self.cfg.get("token_pooling", False):
            batch_size, wxh, channels = x.shape
            w = h = int(wxh ** 0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            # group the feature map into non-overlapping 2x2 patches
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # concatenate along the channel dimension
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # pass through the linear layer
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        if self.cfg.get("conv_fusion_high_low_features", False):
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
            high_x, low_x = x[0], x[1]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
            high_x = x[..., :self.cfg.input_dim[0]]
            low_x = x[..., self.cfg.input_dim[0]:]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'low_high_split_mlp_gelu':
            high_x, low_x = x[0], x[1]
            high_x = self.high_layers(high_x)
            low_x = self.low_layers(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
            return x

        if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
            bs, hw, input_dim = x.shape
            h = w = int(hw ** 0.5)

            # compute padding so the token grid divides evenly by downsample_ratio
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # 4-to-1 concat: merge each downsample_ratio x downsample_ratio block into one token
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)  # B, C*4, HW // 4
            x = x.permute(0, 2, 1)

        return self.layers(x)

    @staticmethod
    def get_flops_per_sample(cfg):
        # rough per-token estimate: 2 * in_features * out_features per linear layer
        if cfg.projector_type == "linear":
            fwd = 2 * cfg.input_dim * cfg.n_embed

        elif "mlp_gelu" in cfg.projector_type:
            mlp_depth = cfg.get("depth", 1)
            downsample_ratio = cfg.get("downsample_ratio", 1)
            input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
            input_dim = input_dim * downsample_ratio * downsample_ratio
            fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
        else:
            fwd = 0

        # fwd * 3 is a common approximation for forward plus backward cost
        return fwd * 3
||||