# Source code for cosense3d.modules.utils.misc

from torch import nn


class SELayer_Linear(nn.Module):
    """Squeeze-and-excitation style gating implemented with linear layers.

    A gating signal is computed from ``x_se`` with
    ``gate(conv_expand(act1(conv_reduce(x_se))))`` and multiplied
    element-wise into ``x``.

    Note that despite their names, ``conv_reduce`` and ``conv_expand`` are
    both ``nn.Linear(channels, channels)``, so no channel reduction happens.

    Args:
        channels (int): feature width of ``x_se``. This is the in/out size of
            both linear layers.
        act_layer (type): callable returning the activation module placed
            between the two linears. Defaults to ``nn.ReLU``.
        gate_layer (type): callable returning the gating module applied last.
            Defaults to ``nn.Sigmoid``.
        norm (bool): stored on the instance as ``self.norm``.
            NOTE(review): it is not read anywhere in this class.
    """

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid, norm=False):
        super(SELayer_Linear, self).__init__()
        # Keep this construction order: each nn.Linear draws from the global
        # RNG at creation time, so reordering would change initial weights.
        self.conv_reduce = nn.Linear(channels, channels)
        self.act1 = act_layer()
        self.conv_expand = nn.Linear(channels, channels)
        self.gate = gate_layer()
        self.norm = norm

    def forward(self, x, x_se):
        """Gate ``x`` with weights derived from ``x_se``.

        Args:
            x: tensor to be gated. Its shape must broadcast against the
                gate output (assumed to share the trailing ``channels`` dim;
                confirm with callers).
            x_se: tensor whose last dimension is ``channels``.

        Returns:
            ``x * gate(conv_expand(act1(conv_reduce(x_se))))``.
        """
        hidden = self.act1(self.conv_reduce(x_se))
        weights = self.gate(self.conv_expand(hidden))
        return x * weights
class MLN(nn.Module):
    """Conditional layer norm: an affine-free LayerNorm modulated by a code ``c``.

    ``forward(x, c)`` returns ``gamma(h) * LayerNorm(x) + beta(h)`` with
    ``h = ReLU(Linear(c))``. The gamma/beta layers are initialised with zero
    weights and biases of 1 / 0 respectively. At construction the module
    therefore returns exactly ``LayerNorm(x)``.

    Args:
        c_dim (int): dimension of the latent code ``c`` (last dim of ``c``).
        f_dim (int): feature dimension (last dim of ``x``). Defaults to 256.
    """

    def __init__(self, c_dim, f_dim=256):
        super(MLN, self).__init__()
        self.c_dim, self.f_dim = c_dim, f_dim
        # Construction order is kept stable so RNG-dependent default
        # initialisation (reduce's Linear) is reproducible.
        self.reduce = nn.Sequential(nn.Linear(c_dim, f_dim), nn.ReLU())
        self.gamma = nn.Linear(f_dim, f_dim)
        self.beta = nn.Linear(f_dim, f_dim)
        # No learnable affine here: scale/shift come from gamma/beta instead.
        self.ln = nn.LayerNorm(f_dim, elementwise_affine=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Make the modulation start as identity: gamma == 1, beta == 0."""
        for layer, init_bias in ((self.gamma, nn.init.ones_), (self.beta, nn.init.zeros_)):
            nn.init.zeros_(layer.weight)
            init_bias(layer.bias)

    def forward(self, x, c):
        """Normalise ``x`` and modulate it with scale/shift predicted from ``c``.

        Args:
            x: features whose last dimension is ``f_dim``.
            c: conditioning code whose last dimension is ``c_dim``. Leading
                dims are assumed to broadcast with ``x`` (confirm with callers).

        Returns:
            ``gamma(h) * LayerNorm(x) + beta(h)``, where ``h = reduce(c)``.
        """
        cond = self.reduce(c)
        return self.gamma(cond) * self.ln(x) + self.beta(cond)
class MLN2(nn.Module):
    """Variant of conditional layer norm with bounded scale and normalised shift.

    ``forward(x, c)`` returns ``gamma(h) * LayerNorm(x) + beta(h)`` where
    ``h = ReLU(LayerNorm(Linear(c)))``, ``gamma = Sigmoid(Linear(h))``
    (so the scale lies in (0, 1)) and ``beta = LayerNorm(Linear(h))``.

    After ``reset_parameters`` the gamma/beta linears have zero weights, so
    the initial scale is ``sigmoid(1) ≈ 0.731`` everywhere and the initial
    shift is 0 (LayerNorm of an all-zero vector).
    NOTE(review): unlike ``MLN``, this does not start as an exact identity
    modulation. Confirm that this is intended.

    Args:
        c_dim (int): dimension of the latent code ``c`` (last dim of ``c``).
        f_dim (int): feature dimension (last dim of ``x``). Defaults to 256.
    """

    def __init__(self, c_dim, f_dim=256):
        super(MLN2, self).__init__()
        self.c_dim, self.f_dim = c_dim, f_dim
        # Order matters for reproducible RNG-based default init and for
        # state_dict keys (reduce.0/1, gamma.0, beta.0/1).
        self.reduce = nn.Sequential(
            nn.Linear(c_dim, f_dim),
            nn.LayerNorm(f_dim),
            nn.ReLU(),
        )
        self.gamma = nn.Sequential(nn.Linear(f_dim, f_dim), nn.Sigmoid())
        self.beta = nn.Sequential(nn.Linear(f_dim, f_dim), nn.LayerNorm(f_dim))
        self.ln = nn.LayerNorm(f_dim, elementwise_affine=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the gamma/beta linear weights; gamma bias = 1, beta bias = 0."""
        gamma_fc, beta_fc = self.gamma[0], self.beta[0]
        for fc in (gamma_fc, beta_fc):
            nn.init.zeros_(fc.weight)
        nn.init.ones_(gamma_fc.bias)
        nn.init.zeros_(beta_fc.bias)

    def forward(self, x, c):
        """Normalise ``x`` and modulate it with scale/shift predicted from ``c``.

        Args:
            x: features whose last dimension is ``f_dim``.
            c: conditioning code whose last dimension is ``c_dim``. Leading
                dims are assumed to broadcast with ``x`` (confirm with callers).

        Returns:
            ``gamma(h) * LayerNorm(x) + beta(h)``, where ``h = reduce(c)``.
        """
        cond = self.reduce(c)
        return self.gamma(cond) * self.ln(x) + self.beta(cond)