# T-MoENet / model/adapter.py
# Uploaded by yixin1121 via huggingface_hub ("Upload folder using huggingface_hub",
# commit 513e1fb, 2.55 kB).
import torch.nn as nn
import torch
import math
class Adapter(nn.Module):
    """Residual bottleneck adapter applied over the last dimension.

    Computes ``x + up(dropout(relu(down(norm_in(x)))))``, optionally followed
    by a LayerNorm on the adapter branch before the residual addition.

    Args:
        ds_factor: Down-sampling factor; the bottleneck width is
            ``hidden_dim // ds_factor``. Must divide ``hidden_dim``.
        hidden_dim: Size of the last dimension of the input.
        ln_after: If true, apply LayerNorm to the adapter branch output
            before adding it back to the input.
        ln_before: If true, apply LayerNorm to the input before the
            down-projection.
        dropout: Dropout probability applied after the activation; a falsy
            value (``0``, ``None``) disables dropout.
    """

    def __init__(
        self, ds_factor, hidden_dim, ln_after=False, ln_before=False, dropout=0.1
    ):
        super().__init__()
        assert not hidden_dim % ds_factor, (
            f"hidden_dim ({hidden_dim}) must be divisible by ds_factor ({ds_factor})"
        )
        self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
        self.ln_after = ln_after
        self.ln_before = ln_before
        # Holds the raw (falsy) value when dropout is disabled, otherwise an
        # nn.Dropout module; forward() tests its truthiness.
        self.dropout = dropout
        if ln_after or ln_before:
            # NOTE(review): a single LayerNorm is shared when both ln_before
            # and ln_after are set.
            self.ln = nn.LayerNorm(hidden_dim)
        if dropout:
            self.dropout = nn.Dropout(dropout)
        # Initialize after every submodule exists so init_weights also sees
        # the LayerNorm (LayerNorm init draws no random numbers, so the
        # Linear initialization is unaffected by this ordering).
        self.apply(self.init_weights)

    def init_weights(self, m: nn.Module, std=1e-3):
        """Initialize one submodule; used via ``self.apply``.

        Linear layers get N(0, std) weights and biases clamped to
        ``[-2*std, 2*std]`` (a clamp, not a resample), so the adapter starts
        close to an identity mapping. LayerNorm gets weight 1 and bias 0.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, std=std)
            with torch.no_grad():
                m.weight.clamp_(min=-2 * std, max=2 * std)
            if m.bias is not None:
                torch.nn.init.normal_(m.bias, std=std)
                with torch.no_grad():
                    m.bias.clamp_(min=-2 * std, max=2 * std)
        elif isinstance(m, nn.LayerNorm):
            m.bias.data.zero_()
            m.weight.data.fill_(1.0)

    def forward(self, hidden_states):
        """Return ``hidden_states`` plus the adapter branch output."""
        residual = self.ln(hidden_states) if self.ln_before else hidden_states
        residual = self.act(self.down(residual))
        if self.dropout:
            residual = self.dropout(residual)
        residual = self.up(residual)
        if self.ln_after:
            # Normalize the adapter branch itself; normalizing hidden_states
            # here would discard the down/up projections entirely.
            residual = self.ln(residual)
        return hidden_states + residual
class ST_Adapter(nn.Module):
    """Residual adapter with a depthwise temporal convolution in its bottleneck.

    The input is projected down to ``hidden_dim // ds_factor`` channels and
    convolved along the second (sequence) axis with a depthwise,
    length-preserving kernel of size 3. It is then projected back to
    ``hidden_dim`` and added to the input.

    The conv weight/bias and both Linear biases are zero-initialized, so a
    freshly constructed module returns its input unchanged.

    Args:
        ds_factor: Down-sampling factor for the bottleneck width.
        hidden_dim: Size of the last dimension of the input.
    """

    def __init__(self, ds_factor, hidden_dim):
        super().__init__()
        bottleneck = hidden_dim // ds_factor
        # Creation order (down, conv, up) is kept so random init is unchanged.
        self.down = nn.Linear(hidden_dim, bottleneck)
        self.conv = nn.Conv1d(
            bottleneck,
            bottleneck,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=bottleneck,
        )
        self.up = nn.Linear(bottleneck, hidden_dim)
        for param in (self.conv.weight, self.conv.bias, self.down.bias, self.up.bias):
            nn.init.constant_(param, 0.)

    def forward(self, x):
        """Apply the adapter to a 3-D input ``(N, T, hidden_dim)``."""
        # Unpacking is kept only to reject non-3-D input (raises ValueError).
        _batch, _steps, _channels = x.shape
        identity = x
        # Conv1d expects (N, channels, length): swap the sequence and channel axes.
        hidden = self.down(x).transpose(1, 2).contiguous()
        hidden = self.conv(hidden).transpose(1, 2).contiguous()
        return self.up(hidden) + identity